Joseph Pollack committed: adds the initial iterative and deep research workflows

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitignore +3 -0
- .pre-commit-config.yaml +1 -0
- AGENTS.md +0 -118
- CLAUDE.md +0 -111
- GEMINI.md +0 -98
- docs/CONFIGURATION.md +291 -0
- docs/architecture/graph_orchestration.md +141 -0
- docs/examples/writer_agents_usage.md +415 -0
- docs/implementation/02_phase_search.md +31 -19
- pyproject.toml +11 -0
- src/agent_factory/agents.py +339 -0
- src/agent_factory/graph_builder.py +608 -0
- src/agent_factory/judges.py +9 -0
- src/agents/input_parser.py +178 -0
- src/agents/judge_agent.py +1 -1
- src/agents/knowledge_gap.py +156 -0
- src/agents/long_writer.py +431 -0
- src/agents/proofreader.py +205 -0
- src/agents/search_agent.py +1 -1
- src/agents/state.py +27 -5
- src/agents/thinking.py +148 -0
- src/agents/tool_selector.py +168 -0
- src/agents/writer.py +209 -0
- src/{orchestrator.py → legacy_orchestrator.py} +0 -0
- src/middleware/__init__.py +33 -0
- src/middleware/budget_tracker.py +390 -0
- src/middleware/state_machine.py +129 -0
- src/middleware/workflow_manager.py +322 -0
- src/orchestrator/__init__.py +48 -0
- src/orchestrator/graph_orchestrator.py +953 -0
- src/orchestrator/planner_agent.py +174 -0
- src/orchestrator/research_flow.py +999 -0
- src/orchestrator_factory.py +1 -1
- src/tools/__init__.py +8 -1
- src/tools/crawl_adapter.py +58 -0
- src/tools/rag_tool.py +183 -0
- src/tools/search_handler.py +67 -5
- src/tools/tool_executor.py +193 -0
- src/tools/web_search_adapter.py +63 -0
- src/utils/citation_validator.py +91 -0
- src/utils/config.py +98 -0
- src/utils/models.py +267 -1
- tests/integration/test_deep_research.py +352 -0
- tests/integration/test_middleware_integration.py +245 -0
- tests/integration/test_parallel_loops_judge.py +396 -0
- tests/integration/test_rag_integration.py +343 -0
- tests/integration/test_research_flows.py +584 -0
- tests/unit/agent_factory/test_graph_builder.py +439 -0
- tests/unit/agents/test_input_parser.py +325 -0
- tests/unit/agents/test_long_writer.py +509 -0
.gitignore
CHANGED

@@ -1,3 +1,6 @@
+folder/
+.cursor/
+.ruff_cache/
 # Python
 __pycache__/
 *.py[cod]
.pre-commit-config.yaml
CHANGED

@@ -13,6 +13,7 @@ repos:
     hooks:
       - id: mypy
        files: ^src/
+       exclude: ^folder
        additional_dependencies:
          - pydantic>=2.7
          - pydantic-settings>=2.2
AGENTS.md
DELETED

@@ -1,118 +0,0 @@
-# AGENTS.md
-
-This file provides guidance to AI agents when working with code in this repository.
-
-## Project Overview
-
-DeepCritical is an AI-native drug repurposing research agent for a HuggingFace hackathon. It uses a search-and-judge loop to autonomously search biomedical databases (PubMed, ClinicalTrials.gov, bioRxiv) and synthesize evidence for queries like "What existing drugs might help treat long COVID fatigue?".
-
-**Current Status:** Phases 1-13 COMPLETE (Foundation through Modal sandbox integration).
-
-## Development Commands
-
-```bash
-# Install all dependencies (including dev)
-make install  # or: uv sync --all-extras && uv run pre-commit install
-
-# Run all quality checks (lint + typecheck + test) - MUST PASS BEFORE COMMIT
-make check
-
-# Individual commands
-make test       # uv run pytest tests/unit/ -v
-make lint       # uv run ruff check src tests
-make format     # uv run ruff format src tests
-make typecheck  # uv run mypy src
-make test-cov   # uv run pytest --cov=src --cov-report=term-missing
-
-# Run single test
-uv run pytest tests/unit/utils/test_config.py::TestSettings::test_default_max_iterations -v
-
-# Integration tests (real APIs)
-uv run pytest -m integration
-```
-
-## Architecture
-
-**Pattern**: Search-and-judge loop with multi-tool orchestration.
-
-```text
-User Question → Orchestrator
-    ↓
-Search Loop:
-  1. Query PubMed, ClinicalTrials.gov, bioRxiv
-  2. Gather evidence
-  3. Judge quality ("Do we have enough?")
-  4. If NO → Refine query, search more
-  5. If YES → Synthesize findings (+ optional Modal analysis)
-    ↓
-Research Report with Citations
-```
-
-**Key Components**:
-
-- `src/orchestrator.py` - Main agent loop
-- `src/tools/pubmed.py` - PubMed E-utilities search
-- `src/tools/clinicaltrials.py` - ClinicalTrials.gov API
-- `src/tools/biorxiv.py` - bioRxiv/medRxiv preprint search
-- `src/tools/code_execution.py` - Modal sandbox execution
-- `src/tools/search_handler.py` - Scatter-gather orchestration
-- `src/services/embeddings.py` - Semantic search & deduplication (ChromaDB)
-- `src/services/statistical_analyzer.py` - Statistical analysis via Modal
-- `src/agent_factory/judges.py` - LLM-based evidence assessment
-- `src/agents/` - Magentic multi-agent mode (SearchAgent, JudgeAgent, etc.)
-- `src/mcp_tools.py` - MCP tool wrappers for Claude Desktop
-- `src/utils/config.py` - Pydantic Settings (loads from `.env`)
-- `src/utils/models.py` - Evidence, Citation, SearchResult models
-- `src/utils/exceptions.py` - Exception hierarchy
-- `src/app.py` - Gradio UI with MCP server (HuggingFace Spaces)
-
-**Break Conditions**: Judge approval, token budget (50K max), or max iterations (default 10).
-
-## Configuration
-
-Settings via pydantic-settings from `.env`:
-
-- `LLM_PROVIDER`: "openai" or "anthropic"
-- `OPENAI_API_KEY` / `ANTHROPIC_API_KEY`: LLM keys
-- `NCBI_API_KEY`: Optional, for higher PubMed rate limits
-- `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: For Modal sandbox (optional)
-- `MAX_ITERATIONS`: 1-50, default 10
-- `LOG_LEVEL`: DEBUG, INFO, WARNING, ERROR
-
-## Exception Hierarchy
-
-```text
-DeepCriticalError (base)
-├── SearchError
-│   └── RateLimitError
-├── JudgeError
-└── ConfigurationError
-```
-
-## Testing
-
-- **TDD**: Write tests first in `tests/unit/`, implement in `src/`
-- **Markers**: `unit`, `integration`, `slow`
-- **Mocking**: `respx` for httpx, `pytest-mock` for general mocking
-- **Fixtures**: `tests/conftest.py` has `mock_httpx_client`, `mock_llm_response`
-
-## Coding Standards
-
-- Python 3.11+, strict mypy, ruff (100-char lines)
-- Type all functions, use Pydantic models for data
-- Use `structlog` for logging, not print
-- Conventional commits: `feat(scope):`, `fix:`, `docs:`
-
-## Git Workflow
-
-- `main`: Production-ready (GitHub)
-- `dev`: Development integration (GitHub)
-- Remote `origin`: GitHub (source of truth for PRs/code review)
-- Remote `huggingface-upstream`: HuggingFace Spaces (deployment target)
-
-**HuggingFace Spaces Collaboration:**
-
-- Each contributor should use their own dev branch: `yourname-dev` (e.g., `vcms-dev`, `mario-dev`)
-- **DO NOT push directly to `main` or `dev` on HuggingFace** - these can be overwritten easily
-- GitHub is the source of truth; HuggingFace is for deployment/demo
-- Consider using git hooks to prevent accidental pushes to protected branches
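Editor's note: the search-and-judge loop described in the deleted AGENTS.md above maps onto a small control loop. The following is a minimal illustrative sketch, not the actual `src/orchestrator.py` code; the `Verdict` shape and the `tool.search`/`judge.assess` interfaces are assumptions, and the token-budget break condition is omitted for brevity.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Verdict:
    sufficient: bool
    refined_query: str  # used when the judge asks for another pass


async def search_and_judge(question: str, tools, judge, max_iterations: int = 10) -> list:
    """Hypothetical control loop for the search-and-judge pattern described above."""
    evidence: list = []
    query = question
    for _ in range(max_iterations):  # iteration budget (MAX_ITERATIONS)
        # Scatter the query across PubMed / ClinicalTrials.gov / bioRxiv tools.
        batches = await asyncio.gather(*(tool.search(query) for tool in tools))
        for batch in batches:
            evidence.extend(batch)
        verdict: Verdict = await judge.assess(question, evidence)
        if verdict.sufficient:  # judge approval is one break condition
            break
        query = verdict.refined_query  # otherwise refine and search again
    return evidence
```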
CLAUDE.md
DELETED

@@ -1,111 +0,0 @@
-# CLAUDE.md
-
-This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
-
-## Project Overview
-
-DeepCritical is an AI-native drug repurposing research agent for a HuggingFace hackathon. It uses a search-and-judge loop to autonomously search biomedical databases (PubMed, ClinicalTrials.gov, bioRxiv) and synthesize evidence for queries like "What existing drugs might help treat long COVID fatigue?".
-
-**Current Status:** Phases 1-13 COMPLETE (Foundation through Modal sandbox integration).
-
-## Development Commands
-
-```bash
-# Install all dependencies (including dev)
-make install  # or: uv sync --all-extras && uv run pre-commit install
-
-# Run all quality checks (lint + typecheck + test) - MUST PASS BEFORE COMMIT
-make check
-
-# Individual commands
-make test       # uv run pytest tests/unit/ -v
-make lint       # uv run ruff check src tests
-make format     # uv run ruff format src tests
-make typecheck  # uv run mypy src
-make test-cov   # uv run pytest --cov=src --cov-report=term-missing
-
-# Run single test
-uv run pytest tests/unit/utils/test_config.py::TestSettings::test_default_max_iterations -v
-
-# Integration tests (real APIs)
-uv run pytest -m integration
-```
-
-## Architecture
-
-**Pattern**: Search-and-judge loop with multi-tool orchestration.
-
-```text
-User Question → Orchestrator
-    ↓
-Search Loop:
-  1. Query PubMed, ClinicalTrials.gov, bioRxiv
-  2. Gather evidence
-  3. Judge quality ("Do we have enough?")
-  4. If NO → Refine query, search more
-  5. If YES → Synthesize findings (+ optional Modal analysis)
-    ↓
-Research Report with Citations
-```
-
-**Key Components**:
-
-- `src/orchestrator.py` - Main agent loop
-- `src/tools/pubmed.py` - PubMed E-utilities search
-- `src/tools/clinicaltrials.py` - ClinicalTrials.gov API
-- `src/tools/biorxiv.py` - bioRxiv/medRxiv preprint search
-- `src/tools/code_execution.py` - Modal sandbox execution
-- `src/tools/search_handler.py` - Scatter-gather orchestration
-- `src/services/embeddings.py` - Semantic search & deduplication (ChromaDB)
-- `src/services/statistical_analyzer.py` - Statistical analysis via Modal
-- `src/agent_factory/judges.py` - LLM-based evidence assessment
-- `src/agents/` - Magentic multi-agent mode (SearchAgent, JudgeAgent, etc.)
-- `src/mcp_tools.py` - MCP tool wrappers for Claude Desktop
-- `src/utils/config.py` - Pydantic Settings (loads from `.env`)
-- `src/utils/models.py` - Evidence, Citation, SearchResult models
-- `src/utils/exceptions.py` - Exception hierarchy
-- `src/app.py` - Gradio UI with MCP server (HuggingFace Spaces)
-
-**Break Conditions**: Judge approval, token budget (50K max), or max iterations (default 10).
-
-## Configuration
-
-Settings via pydantic-settings from `.env`:
-
-- `LLM_PROVIDER`: "openai" or "anthropic"
-- `OPENAI_API_KEY` / `ANTHROPIC_API_KEY`: LLM keys
-- `NCBI_API_KEY`: Optional, for higher PubMed rate limits
-- `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: For Modal sandbox (optional)
-- `MAX_ITERATIONS`: 1-50, default 10
-- `LOG_LEVEL`: DEBUG, INFO, WARNING, ERROR
-
-## Exception Hierarchy
-
-```text
-DeepCriticalError (base)
-├── SearchError
-│   └── RateLimitError
-├── JudgeError
-└── ConfigurationError
-```
-
-## Testing
-
-- **TDD**: Write tests first in `tests/unit/`, implement in `src/`
-- **Markers**: `unit`, `integration`, `slow`
-- **Mocking**: `respx` for httpx, `pytest-mock` for general mocking
-- **Fixtures**: `tests/conftest.py` has `mock_httpx_client`, `mock_llm_response`
-
-## Git Workflow
-
-- `main`: Production-ready (GitHub)
-- `dev`: Development integration (GitHub)
-- Remote `origin`: GitHub (source of truth for PRs/code review)
-- Remote `huggingface-upstream`: HuggingFace Spaces (deployment target)
-
-**HuggingFace Spaces Collaboration:**
-
-- Each contributor should use their own dev branch: `yourname-dev` (e.g., `vcms-dev`, `mario-dev`)
-- **DO NOT push directly to `main` or `dev` on HuggingFace** - these can be overwritten easily
-- GitHub is the source of truth; HuggingFace is for deployment/demo
-- Consider using git hooks to prevent accidental pushes to protected branches
GEMINI.md
DELETED

@@ -1,98 +0,0 @@
-# DeepCritical Context
-
-## Project Overview
-
-**DeepCritical** is an AI-native Medical Drug Repurposing Research Agent.
-**Goal:** To accelerate the discovery of new uses for existing drugs by intelligently searching biomedical literature (PubMed, ClinicalTrials.gov, bioRxiv), evaluating evidence, and hypothesizing potential applications.
-
-**Architecture:**
-The project follows a **Vertical Slice Architecture** (Search -> Judge -> Orchestrator) and adheres to **Strict TDD** (Test-Driven Development).
-
-**Current Status:**
-
-- **Phases 1-9:** COMPLETE. Foundation, Search, Judge, UI, Orchestrator, Embeddings, Hypothesis, Report, Cleanup.
-- **Phases 10-11:** COMPLETE. ClinicalTrials.gov and bioRxiv integration.
-- **Phase 12:** COMPLETE. MCP Server integration (Gradio MCP at `/gradio_api/mcp/`).
-- **Phase 13:** COMPLETE. Modal sandbox for statistical analysis.
-
-## Tech Stack & Tooling
-
-- **Language:** Python 3.11 (Pinned)
-- **Package Manager:** `uv` (Rust-based, extremely fast)
-- **Frameworks:** `pydantic`, `pydantic-ai`, `httpx`, `gradio[mcp]`
-- **Vector DB:** `chromadb` with `sentence-transformers` for semantic search
-- **Code Execution:** `modal` for secure sandboxed Python execution
-- **Testing:** `pytest`, `pytest-asyncio`, `respx` (for mocking)
-- **Quality:** `ruff` (linting/formatting), `mypy` (strict type checking), `pre-commit`
-
-## Building & Running
-
-| Command | Description |
-| :--- | :--- |
-| `make install` | Install dependencies and pre-commit hooks. |
-| `make test` | Run unit tests. |
-| `make lint` | Run Ruff linter. |
-| `make format` | Run Ruff formatter. |
-| `make typecheck` | Run Mypy static type checker. |
-| `make check` | **The Golden Gate:** Runs lint, typecheck, and test. Must pass before committing. |
-| `make clean` | Clean up cache and artifacts. |
-
-## Directory Structure
-
-- `src/`: Source code
-  - `utils/`: Shared utilities (`config.py`, `exceptions.py`, `models.py`)
-  - `tools/`: Search tools (`pubmed.py`, `clinicaltrials.py`, `biorxiv.py`, `code_execution.py`)
-  - `services/`: Services (`embeddings.py`, `statistical_analyzer.py`)
-  - `agents/`: Magentic multi-agent mode agents
-  - `agent_factory/`: Agent definitions (judges, prompts)
-  - `mcp_tools.py`: MCP tool wrappers for Claude Desktop integration
-  - `app.py`: Gradio UI with MCP server
-- `tests/`: Test suite
-  - `unit/`: Isolated unit tests (Mocked)
-  - `integration/`: Real API tests (Marked as slow/integration)
-- `docs/`: Documentation and Implementation Specs
-- `examples/`: Working demos for each phase
-
-## Key Components
-
-- `src/orchestrator.py` - Main agent loop
-- `src/tools/pubmed.py` - PubMed E-utilities search
-- `src/tools/clinicaltrials.py` - ClinicalTrials.gov API
-- `src/tools/biorxiv.py` - bioRxiv/medRxiv preprint search
-- `src/tools/code_execution.py` - Modal sandbox execution
-- `src/services/statistical_analyzer.py` - Statistical analysis via Modal
-- `src/mcp_tools.py` - MCP tool wrappers
-- `src/app.py` - Gradio UI (HuggingFace Spaces) with MCP server
-
-## Configuration
-
-Settings via pydantic-settings from `.env`:
-
-- `LLM_PROVIDER`: "openai" or "anthropic"
-- `OPENAI_API_KEY` / `ANTHROPIC_API_KEY`: LLM keys
-- `NCBI_API_KEY`: Optional, for higher PubMed rate limits
-- `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: For Modal sandbox (optional)
-- `MAX_ITERATIONS`: 1-50, default 10
-- `LOG_LEVEL`: DEBUG, INFO, WARNING, ERROR
-
-## Development Conventions
-
-1. **Strict TDD:** Write failing tests in `tests/unit/` *before* implementing logic in `src/`.
-2. **Type Safety:** All code must pass `mypy --strict`. Use Pydantic models for data exchange.
-3. **Linting:** Zero tolerance for Ruff errors.
-4. **Mocking:** Use `respx` or `unittest.mock` for all external API calls in unit tests.
-5. **Vertical Slices:** Implement features end-to-end rather than layer-by-layer.
-
-## Git Workflow
-
-- `main`: Production-ready (GitHub)
-- `dev`: Development integration (GitHub)
-- Remote `origin`: GitHub (source of truth for PRs/code review)
-- Remote `huggingface-upstream`: HuggingFace Spaces (deployment target)
-
-**HuggingFace Spaces Collaboration:**
-
-- Each contributor should use their own dev branch: `yourname-dev` (e.g., `vcms-dev`, `mario-dev`)
-- **DO NOT push directly to `main` or `dev` on HuggingFace** - these can be overwritten easily
-- GitHub is the source of truth; HuggingFace is for deployment/demo
-- Consider using git hooks to prevent accidental pushes to protected branches
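Editor's note: the exception hierarchy listed in the deleted AGENTS.md/CLAUDE.md above is small enough to sketch directly. The class bodies below are reconstructed from that tree and are an assumption about `src/utils/exceptions.py`, not a copy of it.

```python
class DeepCriticalError(Exception):
    """Base class for all DeepCritical errors."""


class SearchError(DeepCriticalError):
    """Raised when a search tool (PubMed, ClinicalTrials.gov, bioRxiv) fails."""


class RateLimitError(SearchError):
    """Raised when an upstream API rejects requests for rate-limit reasons."""


class JudgeError(DeepCriticalError):
    """Raised when LLM-based evidence assessment fails."""


class ConfigurationError(DeepCriticalError):
    """Raised when required settings (e.g., API keys) are missing or invalid."""
```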
docs/CONFIGURATION.md
ADDED

@@ -0,0 +1,291 @@
+# Configuration Guide
+
+## Overview
+
+DeepCritical uses **Pydantic Settings** for centralized configuration management. All settings are defined in `src/utils/config.py` and can be configured via environment variables or a `.env` file.
+
+## Quick Start
+
+1. Copy the example environment file (if available) or create a `.env` file in the project root
+2. Set at least one LLM API key (`OPENAI_API_KEY` or `ANTHROPIC_API_KEY`)
+3. Optionally configure other services as needed
+
+## Configuration System
+
+### How It Works
+
+- **Settings Class**: `Settings` class in `src/utils/config.py` extends `BaseSettings` from `pydantic_settings`
+- **Environment File**: Automatically loads from `.env` file (if present)
+- **Environment Variables**: Reads from environment variables (case-insensitive)
+- **Type Safety**: Strongly-typed fields with validation
+- **Singleton Pattern**: Global `settings` instance for easy access
+
+### Usage
+
+```python
+from src.utils.config import settings
+
+# Check if API keys are available
+if settings.has_openai_key:
+    # Use OpenAI
+    pass
+
+# Access configuration values
+max_iterations = settings.max_iterations
+web_search_provider = settings.web_search_provider
+```
+
+## Required Configuration
+
+### At Least One LLM Provider
+
+You must configure at least one LLM provider:
+
+**OpenAI:**
+```bash
+LLM_PROVIDER=openai
+OPENAI_API_KEY=your_openai_api_key_here
+OPENAI_MODEL=gpt-5.1
+```
+
+**Anthropic:**
+```bash
+LLM_PROVIDER=anthropic
+ANTHROPIC_API_KEY=your_anthropic_api_key_here
+ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
+```
+
+## Optional Configuration
+
+### Embedding Configuration
+
+```bash
+# Embedding Provider: "openai", "local", or "huggingface"
+EMBEDDING_PROVIDER=local
+
+# OpenAI Embedding Model (used by LlamaIndex RAG)
+OPENAI_EMBEDDING_MODEL=text-embedding-3-small
+
+# Local Embedding Model (sentence-transformers)
+LOCAL_EMBEDDING_MODEL=all-MiniLM-L6-v2
+
+# HuggingFace Embedding Model
+HUGGINGFACE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
+```
+
+### HuggingFace Configuration
+
+```bash
+# HuggingFace API Token (for inference API)
+HUGGINGFACE_API_KEY=your_huggingface_api_key_here
+# Or use HF_TOKEN (alternative name)
+
+# Default HuggingFace Model ID
+HUGGINGFACE_MODEL=meta-llama/Llama-3.1-8B-Instruct
+```
+
+### Web Search Configuration
+
+```bash
+# Web Search Provider: "serper", "searchxng", "brave", "tavily", or "duckduckgo"
+# Default: "duckduckgo" (no API key required)
+WEB_SEARCH_PROVIDER=duckduckgo
+
+# Serper API Key (for Google search via Serper)
+SERPER_API_KEY=your_serper_api_key_here
+
+# SearchXNG Host URL
+SEARCHXNG_HOST=http://localhost:8080
+
+# Brave Search API Key
+BRAVE_API_KEY=your_brave_api_key_here
+
+# Tavily API Key
+TAVILY_API_KEY=your_tavily_api_key_here
+```
+
+### PubMed Configuration
+
+```bash
+# NCBI API Key (optional, for higher rate limits: 10 req/sec vs 3 req/sec)
+NCBI_API_KEY=your_ncbi_api_key_here
+```
+
+### Agent Configuration
+
+```bash
+# Maximum iterations per research loop
+MAX_ITERATIONS=10
+
+# Search timeout in seconds
+SEARCH_TIMEOUT=30
+
+# Use graph-based execution for research flows
+USE_GRAPH_EXECUTION=false
+```
+
+### Budget & Rate Limiting Configuration
+
+```bash
+# Default token budget per research loop
+DEFAULT_TOKEN_LIMIT=100000
+
+# Default time limit per research loop (minutes)
+DEFAULT_TIME_LIMIT_MINUTES=10
+
+# Default iterations limit per research loop
+DEFAULT_ITERATIONS_LIMIT=10
+```
+
+### RAG Service Configuration
+
+```bash
+# ChromaDB collection name for RAG
+RAG_COLLECTION_NAME=deepcritical_evidence
+
+# Number of top results to retrieve from RAG
+RAG_SIMILARITY_TOP_K=5
+
+# Automatically ingest evidence into RAG
+RAG_AUTO_INGEST=true
+```
+
+### ChromaDB Configuration
+
+```bash
+# ChromaDB storage path
+CHROMA_DB_PATH=./chroma_db
+
+# Whether to persist ChromaDB to disk
+CHROMA_DB_PERSIST=true
+
+# ChromaDB server host (for remote ChromaDB, optional)
+# CHROMA_DB_HOST=localhost
+
+# ChromaDB server port (for remote ChromaDB, optional)
+# CHROMA_DB_PORT=8000
+```
+
+### External Services
+
+```bash
+# Modal Token ID (for Modal sandbox execution)
+MODAL_TOKEN_ID=your_modal_token_id_here
+
+# Modal Token Secret
+MODAL_TOKEN_SECRET=your_modal_token_secret_here
+```
+
+### Logging Configuration
+
+```bash
+# Log Level: "DEBUG", "INFO", "WARNING", or "ERROR"
+LOG_LEVEL=INFO
+```
+
+## Configuration Properties
+
+The `Settings` class provides helpful properties for checking configuration:
+
+```python
+from src.utils.config import settings
+
+# Check API key availability
+settings.has_openai_key       # bool
+settings.has_anthropic_key    # bool
+settings.has_huggingface_key  # bool
+settings.has_any_llm_key      # bool
+
+# Check service availability
+settings.modal_available       # bool
+settings.web_search_available  # bool
+```
+
+## Environment Variables Reference
+
+### Required (at least one LLM)
+- `OPENAI_API_KEY` or `ANTHROPIC_API_KEY` - At least one LLM provider key
+
+### Optional LLM Providers
+- `DEEPSEEK_API_KEY` (Phase 2)
+- `OPENROUTER_API_KEY` (Phase 2)
+- `GEMINI_API_KEY` (Phase 2)
+- `PERPLEXITY_API_KEY` (Phase 2)
+- `HUGGINGFACE_API_KEY` or `HF_TOKEN`
+- `AZURE_OPENAI_ENDPOINT` (Phase 2)
+- `AZURE_OPENAI_DEPLOYMENT` (Phase 2)
+- `AZURE_OPENAI_API_KEY` (Phase 2)
+- `AZURE_OPENAI_API_VERSION` (Phase 2)
+- `LOCAL_MODEL_URL` (Phase 2)
+
+### Web Search
+- `WEB_SEARCH_PROVIDER` (default: "duckduckgo")
+- `SERPER_API_KEY`
+- `SEARCHXNG_HOST`
+- `BRAVE_API_KEY`
+- `TAVILY_API_KEY`
+
+### Embeddings
+- `EMBEDDING_PROVIDER` (default: "local")
+- `HUGGINGFACE_EMBEDDING_MODEL` (optional)
+
+### RAG
+- `RAG_COLLECTION_NAME` (default: "deepcritical_evidence")
+- `RAG_SIMILARITY_TOP_K` (default: 5)
+- `RAG_AUTO_INGEST` (default: true)
+
+### ChromaDB
+- `CHROMA_DB_PATH` (default: "./chroma_db")
+- `CHROMA_DB_PERSIST` (default: true)
+- `CHROMA_DB_HOST` (optional)
+- `CHROMA_DB_PORT` (optional)
+
+### Budget
+- `DEFAULT_TOKEN_LIMIT` (default: 100000)
+- `DEFAULT_TIME_LIMIT_MINUTES` (default: 10)
+- `DEFAULT_ITERATIONS_LIMIT` (default: 10)
+
+### Other
+- `LLM_PROVIDER` (default: "openai")
+- `NCBI_API_KEY` (optional)
+- `MODAL_TOKEN_ID` (optional)
+- `MODAL_TOKEN_SECRET` (optional)
+- `MAX_ITERATIONS` (default: 10)
+- `LOG_LEVEL` (default: "INFO")
+- `USE_GRAPH_EXECUTION` (default: false)
+
+## Validation
+
+Settings are validated on load using Pydantic validation:
+
+- **Type checking**: All fields are strongly typed
+- **Range validation**: Numeric fields have min/max constraints
+- **Literal validation**: Enum fields only accept specific values
+- **Required fields**: API keys are checked when accessed via `get_api_key()`
+
+## Error Handling
+
+Configuration errors raise `ConfigurationError`:
+
+```python
+from src.utils.config import settings
+from src.utils.exceptions import ConfigurationError
+
+try:
+    api_key = settings.get_api_key()
+except ConfigurationError as e:
+    print(f"Configuration error: {e}")
+```
+
+## Future Enhancements (Phase 2)
+
+The following configurations are planned for Phase 2:
+
+1. **Additional LLM Providers**: DeepSeek, OpenRouter, Gemini, Perplexity, Azure OpenAI, Local models
+2. **Model Selection**: Reasoning/main/fast model configuration
+3. **Service Integration**: Migrate `folder/llm_config.py` to centralized config
+
+See `CONFIGURATION_ANALYSIS.md` for the complete implementation plan.
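Editor's note: to make the validation behavior documented above concrete, here is a minimal sketch of how such a `Settings` class might be declared with pydantic-settings. Field names follow the variables documented in the file, but the actual class in `src/utils/config.py` may differ.

```python
from typing import Literal

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Illustrative subset of the documented configuration surface."""

    model_config = SettingsConfigDict(env_file=".env", case_sensitive=False)

    # Literal validation: only the listed providers are accepted.
    llm_provider: Literal["openai", "anthropic"] = "openai"
    openai_api_key: str | None = None
    anthropic_api_key: str | None = None

    # Range validation: MAX_ITERATIONS is documented as 1-50, default 10.
    max_iterations: int = Field(default=10, ge=1, le=50)
    default_token_limit: int = Field(default=100_000, ge=1)

    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"

    @property
    def has_openai_key(self) -> bool:
        return self.openai_api_key is not None


# Singleton-style global instance, as the guide describes.
settings = Settings()
```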
docs/architecture/graph_orchestration.md
ADDED

@@ -0,0 +1,141 @@
+# Graph Orchestration Architecture
+
+## Overview
+
+Phase 4 implements a graph-based orchestration system for research workflows using Pydantic AI agents as nodes. This enables better parallel execution, conditional routing, and state management compared to simple agent chains.
+
+## Graph Structure
+
+### Nodes
+
+Graph nodes represent different stages in the research workflow:
+
+1. **Agent Nodes**: Execute Pydantic AI agents
+   - Input: Prompt/query
+   - Output: Structured or unstructured response
+   - Examples: `KnowledgeGapAgent`, `ToolSelectorAgent`, `ThinkingAgent`
+
+2. **State Nodes**: Update or read workflow state
+   - Input: Current state
+   - Output: Updated state
+   - Examples: Update evidence, update conversation history
+
+3. **Decision Nodes**: Make routing decisions based on conditions
+   - Input: Current state/results
+   - Output: Next node ID
+   - Examples: Continue research vs. complete research
+
+4. **Parallel Nodes**: Execute multiple nodes concurrently
+   - Input: List of node IDs
+   - Output: Aggregated results
+   - Examples: Parallel iterative research loops
+
+### Edges
+
+Edges define transitions between nodes:
+
+1. **Sequential Edges**: Always traversed (no condition)
+   - From: Source node
+   - To: Target node
+   - Condition: None (always True)
+
+2. **Conditional Edges**: Traversed based on condition
+   - From: Source node
+   - To: Target node
+   - Condition: Callable that returns bool
+   - Example: If research complete → go to writer, else → continue loop
+
+3. **Parallel Edges**: Used for parallel execution branches
+   - From: Parallel node
+   - To: Multiple target nodes
+   - Execution: All targets run concurrently
+
+## Graph Patterns
+
+### Iterative Research Graph
+
+```
+[Input] → [Thinking] → [Knowledge Gap] → [Decision: Complete?]
+                                           ↓ No          ↓ Yes
+                                     [Tool Selector]  [Writer]
+                                           ↓
+                                     [Execute Tools] → [Loop Back]
+```
+
+### Deep Research Graph
+
+```
+[Input] → [Planner] → [Parallel Iterative Loops] → [Synthesizer]
+                         ↓         ↓         ↓
+                      [Loop1]   [Loop2]   [Loop3]
+```
+
+## State Management
+
+State is managed via `WorkflowState` using `ContextVar` for thread-safe isolation:
+
+- **Evidence**: Collected evidence from searches
+- **Conversation**: Iteration history (gaps, tool calls, findings, thoughts)
+- **Embedding Service**: For semantic search
+
+State transitions occur at state nodes, which update the global workflow state.
+
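Editor's note: the `ContextVar`-based isolation described above can be sketched as follows. The real `WorkflowState` lives in the middleware modules added by this commit and certainly holds richer types; treat the shape here as illustrative.

```python
from contextvars import ContextVar
from dataclasses import dataclass, field


@dataclass
class WorkflowState:
    """Illustrative state container; the real model likely holds richer types."""
    evidence: list[str] = field(default_factory=list)
    conversation: list[str] = field(default_factory=list)


# Each asyncio task (and thread) sees its own value, so parallel
# research loops cannot clobber each other's state.
_workflow_state: ContextVar[WorkflowState] = ContextVar("workflow_state")


def init_state() -> WorkflowState:
    state = WorkflowState()
    _workflow_state.set(state)
    return state


def get_state() -> WorkflowState:
    return _workflow_state.get()
```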
+## Execution Flow
+
+1. **Graph Construction**: Build graph from nodes and edges
+2. **Graph Validation**: Ensure graph is valid (no cycles, all nodes reachable)
+3. **Graph Execution**: Traverse graph from entry node
+4. **Node Execution**: Execute each node based on type
+5. **Edge Evaluation**: Determine next node(s) based on edges
+6. **Parallel Execution**: Use `asyncio.gather()` for parallel nodes
+7. **State Updates**: Update state at state nodes
+8. **Event Streaming**: Yield events during execution for UI
+
+## Conditional Routing
+
+Decision nodes evaluate conditions and return next node IDs:
+
+- **Knowledge Gap Decision**: If `research_complete` → writer, else → tool selector
+- **Budget Decision**: If budget exceeded → exit, else → continue
+- **Iteration Decision**: If max iterations → exit, else → continue
+
+## Parallel Execution
+
+Parallel nodes execute multiple nodes concurrently:
+
+- Each parallel branch runs independently
+- Results are aggregated after all branches complete
+- State is synchronized after parallel execution
+- Errors in one branch don't stop other branches
+
+## Budget Enforcement
+
+Budget constraints are enforced at decision nodes:
+
+- **Token Budget**: Track LLM token usage
+- **Time Budget**: Track elapsed time
+- **Iteration Budget**: Track iteration count
+
+If any budget is exceeded, execution routes to exit node.
+
+## Error Handling
+
+Errors are handled at multiple levels:
+
+1. **Node Level**: Catch errors in individual node execution
+2. **Graph Level**: Handle errors during graph traversal
+3. **State Level**: Rollback state changes on error
+
+Errors are logged and yield error events for UI.
+
+## Backward Compatibility
+
+Graph execution is optional via feature flag:
+
+- `USE_GRAPH_EXECUTION=true`: Use graph-based execution
+- `USE_GRAPH_EXECUTION=false`: Use agent chain execution (existing)
+
+This allows gradual migration and fallback if needed.
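Editor's note: putting the pieces of the architecture document together, a stripped-down traversal loop might look like the sketch below. The `Node`/`Edge` classes and the traversal are invented for illustration and are far simpler than the actual `src/agent_factory/graph_builder.py` added in this commit.

```python
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field


@dataclass
class Node:
    node_id: str
    run: Callable[[dict], Awaitable[dict]]  # executes the stage, returns updated state
    parallel_children: list[str] = field(default_factory=list)


@dataclass
class Edge:
    source: str
    target: str
    condition: Callable[[dict], bool] = lambda state: True  # sequential by default


async def execute(nodes: dict[str, Node], edges: list[Edge], entry: str) -> dict:
    state: dict = {}
    current: str | None = entry
    while current is not None:
        node = nodes[current]
        if node.parallel_children:
            # Parallel node: run all branches concurrently; a failing branch
            # must not stop the others, so exceptions are collected, not raised.
            results = await asyncio.gather(
                *(nodes[child].run(state) for child in node.parallel_children),
                return_exceptions=True,
            )
            for result in results:
                if not isinstance(result, Exception):
                    state.update(result)  # aggregate branch results into state
        else:
            state = await node.run(state)
        # Edge evaluation: the first edge whose condition holds wins;
        # no matching edge means we reached an exit node.
        current = next(
            (e.target for e in edges if e.source == node.node_id and e.condition(state)),
            None,
        )
    return state
```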
docs/examples/writer_agents_usage.md
ADDED

@@ -0,0 +1,415 @@
+# Writer Agents Usage Examples
+
+This document provides examples of how to use the writer agents in DeepCritical for generating research reports.
+
+## Overview
+
+DeepCritical provides three writer agents for different report generation scenarios:
+
+1. **WriterAgent** - Basic writer for simple reports from findings
+2. **LongWriterAgent** - Iterative writer for long-form multi-section reports
+3. **ProofreaderAgent** - Finalizes and polishes report drafts
+
+## WriterAgent
+
+The `WriterAgent` generates final reports from research findings. It's used in iterative research flows.
+
+### Basic Usage
+
+```python
+from src.agent_factory.agents import create_writer_agent
+
+# Create writer agent
+writer = create_writer_agent()
+
+# Generate report
+query = "What is the capital of France?"
+findings = """
+Paris is the capital of France [1].
+It is located in the north-central part of the country [2].
+
+[1] https://example.com/france-info
+[2] https://example.com/paris-info
+"""
+
+report = await writer.write_report(
+    query=query,
+    findings=findings,
+)
+
+print(report)
+```
+
+### With Output Length Specification
+
+```python
+report = await writer.write_report(
+    query="Explain machine learning",
+    findings=findings,
+    output_length="500 words",
+)
+```
+
+### With Additional Instructions
+
+```python
+report = await writer.write_report(
+    query="Explain machine learning",
+    findings=findings,
+    output_length="A comprehensive overview",
+    output_instructions="Use formal academic language and include examples",
+)
+```
+
+### Integration with IterativeResearchFlow
+
+The `WriterAgent` is automatically used by `IterativeResearchFlow`:
+
+```python
+from src.agent_factory.agents import create_iterative_flow
+
+flow = create_iterative_flow(max_iterations=5, max_time_minutes=10)
+report = await flow.run(
+    query="What is quantum computing?",
+    output_length="A detailed explanation",
+    output_instructions="Include practical applications",
+)
+```
+
+## LongWriterAgent
+
+The `LongWriterAgent` iteratively writes report sections with proper citation management. It's used in deep research flows.
+
+### Basic Usage
+
+```python
+from src.agent_factory.agents import create_long_writer_agent
+from src.utils.models import ReportDraft, ReportDraftSection
+
+# Create long writer agent
+long_writer = create_long_writer_agent()
+
+# Create report draft with sections
+report_draft = ReportDraft(
+    sections=[
+        ReportDraftSection(
+            section_title="Introduction",
+            section_content="Draft content for introduction with [1].",
+        ),
+        ReportDraftSection(
+            section_title="Methods",
+            section_content="Draft content for methods with [2].",
+        ),
+        ReportDraftSection(
+            section_title="Results",
+            section_content="Draft content for results with [3].",
+        ),
+    ]
+)
+
+# Generate full report
+report = await long_writer.write_report(
+    original_query="What are the main features of Python?",
+    report_title="Python Programming Language Overview",
+    report_draft=report_draft,
+)
+
+print(report)
+```
+
+### Writing Individual Sections
+
+You can also write sections one at a time:
+
+```python
+# Write first section
+section_output = await long_writer.write_next_section(
+    original_query="What is Python?",
+    report_draft="",  # No existing draft
+    next_section_title="Introduction",
+    next_section_draft="Python is a programming language...",
+)
+
+print(section_output.next_section_markdown)
+print(section_output.references)
+
+# Write second section with existing draft
+section_output = await long_writer.write_next_section(
+    original_query="What is Python?",
+    report_draft="# Report\n\n## Introduction\n\nContent...",
+    next_section_title="Features",
+    next_section_draft="Python features include...",
+)
+```
+
+### Integration with DeepResearchFlow
+
+The `LongWriterAgent` is automatically used by `DeepResearchFlow`:
+
+```python
+from src.agent_factory.agents import create_deep_flow
+
+flow = create_deep_flow(
+    max_iterations=5,
+    max_time_minutes=10,
+    use_long_writer=True,  # Use long writer (default)
+)
+
+report = await flow.run("What are the main features of Python programming language?")
+```
+
+## ProofreaderAgent
+
+The `ProofreaderAgent` finalizes and polishes report drafts by removing duplicates, adding summaries, and refining wording.
+
+### Basic Usage
+
+```python
+from src.agent_factory.agents import create_proofreader_agent
+from src.utils.models import ReportDraft, ReportDraftSection
+
+# Create proofreader agent
+proofreader = create_proofreader_agent()
+
+# Create report draft
+report_draft = ReportDraft(
+    sections=[
+        ReportDraftSection(
+            section_title="Introduction",
+            section_content="Python is a programming language [1].",
+        ),
+        ReportDraftSection(
+            section_title="Features",
+            section_content="Python has many features [2].",
+        ),
+    ]
+)
+
+# Proofread and finalize
+final_report = await proofreader.proofread(
+    query="What is Python?",
+    report_draft=report_draft,
+)
+
+print(final_report)
+```
+
+### Integration with DeepResearchFlow
+
+Use `ProofreaderAgent` instead of `LongWriterAgent`:
+
+```python
+from src.agent_factory.agents import create_deep_flow
+
+flow = create_deep_flow(
+    max_iterations=5,
+    max_time_minutes=10,
+    use_long_writer=False,  # Use proofreader instead
+)
+
+report = await flow.run("What are the main features of Python?")
+```
+
+## Error Handling
+
+All writer agents include robust error handling:
+
+### Handling Empty Inputs
+
+```python
+# WriterAgent handles empty findings gracefully
+report = await writer.write_report(
+    query="Test query",
+    findings="",  # Empty findings
+)
+# Returns a fallback report
+
+# LongWriterAgent handles empty sections
+report = await long_writer.write_report(
+    original_query="Test",
+    report_title="Test Report",
+    report_draft=ReportDraft(sections=[]),  # Empty draft
+)
+# Returns minimal report
+
+# ProofreaderAgent handles empty drafts
+report = await proofreader.proofread(
+    query="Test",
+    report_draft=ReportDraft(sections=[]),
+)
+# Returns minimal report
+```
+
+### Retry Logic
+
+All agents automatically retry on transient errors (timeouts, connection errors):
+
+```python
+# Automatically retries up to 3 times on transient failures
+report = await writer.write_report(
+    query="Test query",
+    findings=findings,
+)
+```
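Editor's note: for readers who want to see the shape of such retry-then-fallback behavior, here is a small illustrative helper. It is a sketch of the policy described above, not the agents' actual implementation; the exception set and backoff schedule are assumptions.

```python
import asyncio


async def call_with_retry(coro_factory, retries: int = 3, fallback: str = ""):
    """Retry a coroutine-producing callable on transient errors, then fall back."""
    for attempt in range(1, retries + 1):
        try:
            return await coro_factory()
        except (asyncio.TimeoutError, ConnectionError):
            if attempt == retries:
                # All retries exhausted: return the fallback report instead of raising.
                return fallback
            await asyncio.sleep(2 ** attempt)  # simple exponential backoff (assumed)
```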
| 254 |
+
|
| 255 |
+
### Fallback Reports
|
| 256 |
+
|
| 257 |
+
If all retries fail, agents return fallback reports:
|
| 258 |
+
|
| 259 |
+
```python
|
| 260 |
+
# Returns fallback report with query and findings
|
| 261 |
+
report = await writer.write_report(
|
| 262 |
+
query="Test query",
|
| 263 |
+
findings=findings,
|
| 264 |
+
)
|
| 265 |
+
# Fallback includes: "# Research Report\n\n## Query\n...\n\n## Findings\n..."
|
| 266 |
+
```
|
| 267 |
+
|
| 268 |
+
## Citation Validation
|
| 269 |
+
|
| 270 |
+
### For Markdown Reports
|
| 271 |
+
|
| 272 |
+
Use the markdown citation validator:
|
| 273 |
+
|
| 274 |
+
```python
|
| 275 |
+
from src.utils.citation_validator import validate_markdown_citations
|
| 276 |
+
from src.utils.models import Evidence, Citation
|
| 277 |
+
|
| 278 |
+
# Collect evidence during research
|
| 279 |
+
evidence = [
|
| 280 |
+
Evidence(
|
| 281 |
+
content="Paris is the capital of France",
|
| 282 |
+
citation=Citation(
|
| 283 |
+
source="web",
|
| 284 |
+
title="France Information",
|
| 285 |
+
url="https://example.com/france",
|
| 286 |
+
date="2024-01-01",
|
| 287 |
+
),
|
| 288 |
+
),
|
| 289 |
+
]
|
| 290 |
+
|
| 291 |
+
# Generate report
|
| 292 |
+
report = await writer.write_report(query="What is the capital of France?", findings=findings)
|
| 293 |
+
|
| 294 |
+
# Validate citations
|
| 295 |
+
validated_report, removed_count = validate_markdown_citations(report, evidence)
|
| 296 |
+
|
| 297 |
+
if removed_count > 0:
|
| 298 |
+
print(f"Removed {removed_count} invalid citations")
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
### For ResearchReport Objects
|
| 302 |
+
|
| 303 |
+
Use the structured citation validator:
|
| 304 |
+
|
| 305 |
+
```python
|
| 306 |
+
from src.utils.citation_validator import validate_references
|
| 307 |
+
|
| 308 |
+
# For ResearchReport objects (from ReportAgent)
|
| 309 |
+
validated_report = validate_references(report, evidence)
|
| 310 |
+
```
|
| 311 |
+
|
| 312 |
+
## Custom Model Configuration
|
| 313 |
+
|
| 314 |
+
All writer agents support custom model configuration:
|
| 315 |
+
|
| 316 |
+
```python
|
| 317 |
+
from pydantic_ai import Model
|
| 318 |
+
|
| 319 |
+
# Create custom model
|
| 320 |
+
custom_model = Model("openai", "gpt-4")
|
| 321 |
+
|
| 322 |
+
# Use with writer agents
|
| 323 |
+
writer = create_writer_agent(model=custom_model)
|
| 324 |
+
long_writer = create_long_writer_agent(model=custom_model)
|
| 325 |
+
proofreader = create_proofreader_agent(model=custom_model)
|
| 326 |
+
```
|
| 327 |
+
|
| 328 |
+
## Best Practices
|
| 329 |
+
|
| 330 |
+
1. **Use WriterAgent for simple reports** - When you have findings as a string and need a quick report
|
| 331 |
+
2. **Use LongWriterAgent for structured reports** - When you need multiple sections with proper citation management
|
| 332 |
+
3. **Use ProofreaderAgent for final polish** - When you have draft sections and need a polished final report
|
| 333 |
+
4. **Validate citations** - Always validate citations against collected evidence
|
| 334 |
+
5. **Handle errors gracefully** - All agents return fallback reports on failure
|
| 335 |
+
6. **Specify output length** - Use `output_length` parameter to control report size
|
| 336 |
+
7. **Provide instructions** - Use `output_instructions` for specific formatting requirements
|
| 337 |
+
|
| 338 |
+
## Integration Examples
|
| 339 |
+
|
| 340 |
+
### Full Iterative Research Flow
|
| 341 |
+
|
| 342 |
+
```python
|
| 343 |
+
from src.agent_factory.agents import create_iterative_flow
|
| 344 |
+
|
| 345 |
+
flow = create_iterative_flow(
|
| 346 |
+
max_iterations=5,
|
| 347 |
+
max_time_minutes=10,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
report = await flow.run(
|
| 351 |
+
query="What is machine learning?",
|
| 352 |
+
output_length="A comprehensive 1000-word explanation",
|
| 353 |
+
output_instructions="Include practical examples and use cases",
|
| 354 |
+
)
|
| 355 |
+
```
|
| 356 |
+
|
| 357 |
+
### Full Deep Research Flow with Long Writer
|
| 358 |
+
|
| 359 |
+
```python
|
| 360 |
+
from src.agent_factory.agents import create_deep_flow
|
| 361 |
+
|
| 362 |
+
flow = create_deep_flow(
|
| 363 |
+
max_iterations=5,
|
| 364 |
+
max_time_minutes=10,
|
| 365 |
+
use_long_writer=True,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
report = await flow.run("What are the main features of Python programming language?")
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
### Full Deep Research Flow with Proofreader
|
| 372 |
+
|
| 373 |
+
```python
|
| 374 |
+
from src.agent_factory.agents import create_deep_flow
|
| 375 |
+
|
| 376 |
+
flow = create_deep_flow(
|
| 377 |
+
max_iterations=5,
|
| 378 |
+
max_time_minutes=10,
|
| 379 |
+
use_long_writer=False, # Use proofreader
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
report = await flow.run("Explain quantum computing basics")
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
## Troubleshooting
|
| 386 |
+
|
| 387 |
+
### Empty Reports
|
| 388 |
+
|
| 389 |
+
If you get empty reports, check:
|
| 390 |
+
- Input validation logs (agents log warnings for empty inputs)
|
| 391 |
+
- LLM API key configuration
|
| 392 |
+
- Network connectivity
|
| 393 |
+
|
| 394 |
+
### Citation Issues
|
| 395 |
+
|
| 396 |
+
If citations are missing or invalid:
|
| 397 |
+
- Use `validate_markdown_citations()` to check citations
|
| 398 |
+
- Ensure Evidence objects are properly collected during research
|
| 399 |
+
- Check that URLs in findings match Evidence URLs
|
| 400 |
+
|
| 401 |
+
### Performance Issues

For large reports:
- Use `LongWriterAgent` for better section management
- Consider truncating very long findings (agents do this automatically)
- Use appropriate `max_time_minutes` settings

## See Also

- [Research Flows Documentation](../orchestrator/research_flows.md)
- [Citation Validation](../utils/citation_validation.md)
- [Agent Factory](../agent_factory/agents.md)

docs/implementation/02_phase_search.md
CHANGED
@@ -4,6 +4,8 @@
 **Philosophy**: "Real data, mocked connections."
 **Prerequisite**: Phase 1 complete (all tests passing)
 
+> **⚠️ Implementation Note (2025-01-27)**: The DuckDuckGo WebTool specified in this phase was removed in favor of the Europe PMC tool (see Phase 11). Europe PMC provides better coverage for biomedical research by including preprints, peer-reviewed articles, and patents. The current implementation uses PubMed, ClinicalTrials.gov, and Europe PMC as search sources.
+
 ---
 
 ## 1. The Slice Definition
@@ -12,17 +14,20 @@ This slice covers:
 1. **Input**: A string query (e.g., "metformin Alzheimer's disease").
 2. **Process**:
    - Fetch from PubMed (E-utilities API).
-   - Fetch from Web (DuckDuckGo)
+   - ~~Fetch from Web (DuckDuckGo).~~ **REMOVED** - Replaced by Europe PMC in Phase 11
    - Normalize results into `Evidence` models.
 3. **Output**: A list of `Evidence` objects.
 
 **Files to Create**:
 - `src/utils/models.py` - Pydantic models (Evidence, Citation, SearchResult)
 - `src/tools/pubmed.py` - PubMed E-utilities tool
-- `src/tools/websearch.py` - DuckDuckGo search tool
+- ~~`src/tools/websearch.py` - DuckDuckGo search tool~~ **REMOVED** - See Phase 11 for Europe PMC replacement
 - `src/tools/search_handler.py` - Orchestrates multiple tools
 - `src/tools/__init__.py` - Exports
 
+**Additional Files (Post-Phase 2 Enhancements)**:
+- `src/tools/query_utils.py` - Query preprocessing (removes question words, expands medical synonyms)
+
 ---
 
 ## 2. PubMed E-utilities API Reference
@@ -767,17 +772,23 @@ async def test_pubmed_live_search():
 
 ## 8. Implementation Checklist
 
-- [ ] Create `src/utils/models.py` with all Pydantic models (Evidence, Citation, SearchResult)
-- [ ] Create `src/tools/__init__.py` with SearchTool Protocol and exports
-- [ ] Implement `src/tools/pubmed.py` with PubMedTool class
-- [ ] Implement `src/tools/websearch.py` with WebTool class
-- [ ] Create `src/tools/search_handler.py` with SearchHandler class
-- [ ] Write tests in `tests/unit/tools/test_pubmed.py`
-- [ ] Write tests in `tests/unit/tools/test_websearch.py`
-- [ ] Write tests in `tests/unit/tools/test_search_handler.py`
-- [ ] Run `uv run pytest tests/unit/tools/ -v` — ALL TESTS MUST PASS
+- [x] Create `src/utils/models.py` with all Pydantic models (Evidence, Citation, SearchResult) - **COMPLETE**
+- [x] Create `src/tools/__init__.py` with SearchTool Protocol and exports - **COMPLETE**
+- [x] Implement `src/tools/pubmed.py` with PubMedTool class - **COMPLETE**
+- [ ] ~~Implement `src/tools/websearch.py` with WebTool class~~ - **REMOVED** (replaced by Europe PMC in Phase 11)
+- [x] Create `src/tools/search_handler.py` with SearchHandler class - **COMPLETE**
+- [x] Write tests in `tests/unit/tools/test_pubmed.py` - **COMPLETE** (basic tests)
+- [ ] Write tests in `tests/unit/tools/test_websearch.py` - **N/A** (WebTool removed)
+- [x] Write tests in `tests/unit/tools/test_search_handler.py` - **COMPLETE** (basic tests)
+- [x] Run `uv run pytest tests/unit/tools/ -v` — **ALL TESTS MUST PASS** - **PASSING**
 - [ ] (Optional) Run integration test: `uv run pytest -m integration`
-- [ ] Commit: `git commit -m "feat: phase 2 search slice complete"`
+- [ ] Add edge case tests (rate limiting, error handling, timeouts) - **PENDING**
+- [ ] Commit: `git commit -m "feat: phase 2 search slice complete"` - **DONE**
+
+**Post-Phase 2 Enhancements**:
+- [x] Query preprocessing (`src/tools/query_utils.py`) - **ADDED**
+- [x] Europe PMC tool (Phase 11) - **ADDED**
+- [x] ClinicalTrials tool (Phase 10) - **ADDED**
 
 ---
 
@@ -785,20 +796,19 @@ async def test_pubmed_live_search():
 
 Phase 2 is **COMPLETE** when:
 
-1. All unit tests pass: `uv run pytest tests/unit/tools/ -v`
-2. `SearchHandler` can execute with both tools
-3. Graceful degradation: if one tool fails, the other still returns results
-4. Rate limiting is enforced (verify no 429 errors)
-5. Can run this in Python REPL:
+1. ✅ All unit tests pass: `uv run pytest tests/unit/tools/ -v` - **PASSING**
+2. ✅ `SearchHandler` can execute with search tools - **WORKING**
+3. ✅ Graceful degradation: if one tool fails, other tools still return results - **IMPLEMENTED**
+4. ✅ Rate limiting is enforced (verify no 429 errors) - **IMPLEMENTED**
+5. ✅ Can run this in Python REPL:
 
 ```python
 import asyncio
 from src.tools.pubmed import PubMedTool
-from src.tools.websearch import WebTool
 from src.tools.search_handler import SearchHandler
 
 async def test():
-    handler = SearchHandler([PubMedTool(), WebTool()])
+    handler = SearchHandler([PubMedTool()])
     result = await handler.execute("metformin alzheimer")
     print(f"Found {result.total_found} results")
    for e in result.evidence[:3]:
@@ -807,4 +817,6 @@ async def test():
 asyncio.run(test())
 ```
 
+**Note**: WebTool was removed in favor of Europe PMC (Phase 11). The current implementation uses PubMed as the primary Phase 2 tool, with Europe PMC and ClinicalTrials added in later phases.
+
 **Proceed to Phase 3 ONLY after all checkboxes are complete.**

pyproject.toml
CHANGED
@@ -24,6 +24,7 @@ dependencies = [
     "tenacity>=8.2",      # Retry logic
     "structlog>=24.1",    # Structured logging
     "requests>=2.32.5",   # ClinicalTrials.gov (httpx blocked by WAF)
+    "pydantic-graph>=1.22.0",
 ]
 
 [project.optional-dependencies]
@@ -91,6 +92,7 @@ ignore = [
     "PLW0603",  # Global statement (singleton pattern for Modal)
     "PLC0415",  # Lazy imports for optional dependencies
     "E402",     # Module level import not at top (needed for pytest.importorskip)
+    "E501",     # Line too long (ignore line length violations)
     "RUF100",   # Unused noqa (version differences between local/CI)
 ]
 
@@ -105,9 +107,12 @@ ignore_missing_imports = true
 disallow_untyped_defs = true
 warn_return_any = true
 warn_unused_ignores = false
+explicit_package_bases = true
+mypy_path = "."
 exclude = [
     "^reference_repos/",
     "^examples/",
+    "^folder/",
 ]
 
 # ============== PYTEST CONFIG ==============
@@ -137,5 +142,11 @@ exclude_lines = [
     "raise NotImplementedError",
 ]
 
+[dependency-groups]
+dev = [
+    "structlog>=25.5.0",
+    "ty>=0.0.1a28",
+]
+
 # Note: agent-framework-core is optional for magentic mode (multi-agent orchestration)
 # Version pinned to 1.0.0b* to avoid breaking changes. CI skips tests via pytest.importorskip

src/agent_factory/agents.py
ADDED
@@ -0,0 +1,339 @@
"""Agent factory functions for creating research agents.

Provides factory functions for creating all Pydantic AI agents used in
the research workflows, following the pattern from judges.py.
"""

from typing import TYPE_CHECKING, Any

import structlog

from src.utils.config import settings
from src.utils.exceptions import ConfigurationError

if TYPE_CHECKING:
    from src.agent_factory.graph_builder import GraphBuilder
    from src.agents.input_parser import InputParserAgent
    from src.agents.knowledge_gap import KnowledgeGapAgent
    from src.agents.long_writer import LongWriterAgent
    from src.agents.proofreader import ProofreaderAgent
    from src.agents.thinking import ThinkingAgent
    from src.agents.tool_selector import ToolSelectorAgent
    from src.agents.writer import WriterAgent
    from src.orchestrator.graph_orchestrator import GraphOrchestrator
    from src.orchestrator.planner_agent import PlannerAgent
    from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow

logger = structlog.get_logger()


def create_input_parser_agent(model: Any | None = None) -> "InputParserAgent":
    """
    Create input parser agent for query analysis and research mode detection.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured InputParserAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.input_parser import create_input_parser_agent as _create_agent

    try:
        logger.debug("Creating input parser agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create input parser agent", error=str(e))
        raise ConfigurationError(f"Failed to create input parser agent: {e}") from e


def create_planner_agent(model: Any | None = None) -> "PlannerAgent":
    """
    Create planner agent with web search and crawl tools.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured PlannerAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    # Lazy import to avoid circular dependencies
    from src.orchestrator.planner_agent import create_planner_agent as _create_planner_agent

    try:
        logger.debug("Creating planner agent")
        return _create_planner_agent(model=model)
    except Exception as e:
        logger.error("Failed to create planner agent", error=str(e))
        raise ConfigurationError(f"Failed to create planner agent: {e}") from e


def create_knowledge_gap_agent(model: Any | None = None) -> "KnowledgeGapAgent":
    """
    Create knowledge gap agent for evaluating research completeness.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured KnowledgeGapAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.knowledge_gap import create_knowledge_gap_agent as _create_agent

    try:
        logger.debug("Creating knowledge gap agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create knowledge gap agent", error=str(e))
        raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e


def create_tool_selector_agent(model: Any | None = None) -> "ToolSelectorAgent":
    """
    Create tool selector agent for choosing tools to address gaps.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured ToolSelectorAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.tool_selector import create_tool_selector_agent as _create_agent

    try:
        logger.debug("Creating tool selector agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create tool selector agent", error=str(e))
        raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e


def create_thinking_agent(model: Any | None = None) -> "ThinkingAgent":
    """
    Create thinking agent for generating observations.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured ThinkingAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.thinking import create_thinking_agent as _create_agent

    try:
        logger.debug("Creating thinking agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create thinking agent", error=str(e))
        raise ConfigurationError(f"Failed to create thinking agent: {e}") from e


def create_writer_agent(model: Any | None = None) -> "WriterAgent":
    """
    Create writer agent for generating final reports.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured WriterAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.writer import create_writer_agent as _create_agent

    try:
        logger.debug("Creating writer agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create writer agent", error=str(e))
        raise ConfigurationError(f"Failed to create writer agent: {e}") from e


def create_long_writer_agent(model: Any | None = None) -> "LongWriterAgent":
    """
    Create long writer agent for iteratively writing report sections.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured LongWriterAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.long_writer import create_long_writer_agent as _create_agent

    try:
        logger.debug("Creating long writer agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create long writer agent", error=str(e))
        raise ConfigurationError(f"Failed to create long writer agent: {e}") from e


def create_proofreader_agent(model: Any | None = None) -> "ProofreaderAgent":
    """
    Create proofreader agent for finalizing report drafts.

    Args:
        model: Optional Pydantic AI model. If None, uses settings default.

    Returns:
        Configured ProofreaderAgent instance

    Raises:
        ConfigurationError: If required API keys are missing
    """
    from src.agents.proofreader import create_proofreader_agent as _create_agent

    try:
        logger.debug("Creating proofreader agent")
        return _create_agent(model=model)
    except Exception as e:
        logger.error("Failed to create proofreader agent", error=str(e))
        raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e


def create_iterative_flow(
    max_iterations: int = 5,
    max_time_minutes: int = 10,
    verbose: bool = True,
    use_graph: bool | None = None,
) -> "IterativeResearchFlow":
    """
    Create iterative research flow.

    Args:
        max_iterations: Maximum number of iterations
        max_time_minutes: Maximum time in minutes
        verbose: Whether to log progress
        use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution

    Returns:
        Configured IterativeResearchFlow instance
    """
    from src.orchestrator.research_flow import IterativeResearchFlow

    try:
        # Use settings default if not explicitly provided
        if use_graph is None:
            use_graph = settings.use_graph_execution

        logger.debug("Creating iterative research flow", use_graph=use_graph)
        return IterativeResearchFlow(
            max_iterations=max_iterations,
            max_time_minutes=max_time_minutes,
            verbose=verbose,
            use_graph=use_graph,
        )
    except Exception as e:
        logger.error("Failed to create iterative flow", error=str(e))
        raise ConfigurationError(f"Failed to create iterative flow: {e}") from e


def create_deep_flow(
    max_iterations: int = 5,
    max_time_minutes: int = 10,
    verbose: bool = True,
    use_long_writer: bool = True,
    use_graph: bool | None = None,
) -> "DeepResearchFlow":
    """
    Create deep research flow.

    Args:
        max_iterations: Maximum iterations per section
        max_time_minutes: Maximum time per section
        verbose: Whether to log progress
        use_long_writer: Whether to use long writer (True) or proofreader (False)
        use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution

    Returns:
        Configured DeepResearchFlow instance
    """
    from src.orchestrator.research_flow import DeepResearchFlow

    try:
        # Use settings default if not explicitly provided
        if use_graph is None:
            use_graph = settings.use_graph_execution

        logger.debug("Creating deep research flow", use_graph=use_graph)
        return DeepResearchFlow(
            max_iterations=max_iterations,
            max_time_minutes=max_time_minutes,
            verbose=verbose,
            use_long_writer=use_long_writer,
            use_graph=use_graph,
        )
    except Exception as e:
        logger.error("Failed to create deep flow", error=str(e))
        raise ConfigurationError(f"Failed to create deep flow: {e}") from e


def create_graph_orchestrator(
    mode: str = "auto",
    max_iterations: int = 5,
    max_time_minutes: int = 10,
    use_graph: bool = True,
) -> "GraphOrchestrator":
    """
    Create graph orchestrator.

    Args:
        mode: Research mode ("iterative", "deep", or "auto")
        max_iterations: Maximum iterations per loop
        max_time_minutes: Maximum time per loop
        use_graph: Whether to use graph execution (True) or agent chains (False)

    Returns:
        Configured GraphOrchestrator instance
    """
    from src.orchestrator.graph_orchestrator import create_graph_orchestrator as _create

    try:
        logger.debug("Creating graph orchestrator", mode=mode, use_graph=use_graph)
        return _create(
            mode=mode,  # type: ignore[arg-type]
            max_iterations=max_iterations,
            max_time_minutes=max_time_minutes,
            use_graph=use_graph,
        )
    except Exception as e:
        logger.error("Failed to create graph orchestrator", error=str(e))
        raise ConfigurationError(f"Failed to create graph orchestrator: {e}") from e


def create_graph_builder() -> "GraphBuilder":
    """
    Create a graph builder instance.

    Returns:
        GraphBuilder instance
    """
    from src.agent_factory.graph_builder import GraphBuilder

    try:
        logger.debug("Creating graph builder")
        return GraphBuilder()
    except Exception as e:
        logger.error("Failed to create graph builder", error=str(e))
        raise ConfigurationError(f"Failed to create graph builder: {e}") from e

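For orientation, a sketch of how these factories are meant to be wired together. The `run(...)` entry point on the returned orchestrator is an assumption made by analogy with the flow classes above; it is not shown in this file:

```python
# Sketch under assumptions: GraphOrchestrator is taken to expose an async
# run(query) method, by analogy with IterativeResearchFlow/DeepResearchFlow.
import asyncio

from src.agent_factory.agents import create_graph_orchestrator

async def main() -> None:
    orchestrator = create_graph_orchestrator(mode="auto", max_iterations=3, max_time_minutes=5)
    report = await orchestrator.run("What existing drugs might help treat long COVID fatigue?")
    print(report)

asyncio.run(main())
```
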
src/agent_factory/graph_builder.py
ADDED
@@ -0,0 +1,608 @@
"""Graph builder utilities for constructing research workflow graphs.

Provides classes and utilities for building graph-based orchestration systems
using Pydantic AI agents as nodes.
"""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal

import structlog
from pydantic import BaseModel, Field

if TYPE_CHECKING:
    from pydantic_ai import Agent

    from src.middleware.state_machine import WorkflowState

logger = structlog.get_logger()


# ============================================================================
# Graph Node Models
# ============================================================================


class GraphNode(BaseModel):
    """Base class for graph nodes."""

    node_id: str = Field(description="Unique identifier for the node")
    node_type: Literal["agent", "state", "decision", "parallel"] = Field(description="Type of node")
    description: str = Field(default="", description="Human-readable description of the node")

    model_config = {"frozen": True}


class AgentNode(GraphNode):
    """Node that executes a Pydantic AI agent."""

    node_type: Literal["agent"] = "agent"
    agent: Any = Field(description="Pydantic AI agent to execute")
    input_transformer: Callable[[Any], Any] | None = Field(
        default=None, description="Transform input before passing to agent"
    )
    output_transformer: Callable[[Any], Any] | None = Field(
        default=None, description="Transform output after agent execution"
    )

    model_config = {"arbitrary_types_allowed": True}


class StateNode(GraphNode):
    """Node that updates or reads workflow state."""

    node_type: Literal["state"] = "state"
    state_updater: Callable[[Any, Any], Any] = Field(
        description="Function to update workflow state"
    )
    state_reader: Callable[[Any], Any] | None = Field(
        default=None, description="Function to read state (optional)"
    )

    model_config = {"arbitrary_types_allowed": True}


class DecisionNode(GraphNode):
    """Node that makes routing decisions based on conditions."""

    node_type: Literal["decision"] = "decision"
    decision_function: Callable[[Any], str] = Field(
        description="Function that returns next node ID based on input"
    )
    options: list[str] = Field(description="List of possible next node IDs", min_length=1)

    model_config = {"arbitrary_types_allowed": True}


class ParallelNode(GraphNode):
    """Node that executes multiple nodes in parallel."""

    node_type: Literal["parallel"] = "parallel"
    parallel_nodes: list[str] = Field(
        description="List of node IDs to run in parallel", min_length=1
    )
    aggregator: Callable[[list[Any]], Any] | None = Field(
        default=None, description="Function to aggregate parallel results"
    )

    model_config = {"arbitrary_types_allowed": True}


# ============================================================================
# Graph Edge Models
# ============================================================================


class GraphEdge(BaseModel):
    """Base class for graph edges."""

    from_node: str = Field(description="Source node ID")
    to_node: str = Field(description="Target node ID")
    condition: Callable[[Any], bool] | None = Field(
        default=None, description="Optional condition function"
    )
    weight: float = Field(default=1.0, description="Edge weight for routing decisions")

    model_config = {"arbitrary_types_allowed": True}


class SequentialEdge(GraphEdge):
    """Edge that is always traversed (no condition)."""

    condition: None = None


class ConditionalEdge(GraphEdge):
    """Edge that is traversed based on a condition."""

    condition: Callable[[Any], bool] = Field(description="Required condition function")
    condition_description: str = Field(
        default="", description="Human-readable description of condition"
    )


class ParallelEdge(GraphEdge):
    """Edge used for parallel execution branches."""

    condition: None = None


# ============================================================================
# Research Graph Class
# ============================================================================


class ResearchGraph(BaseModel):
    """Represents a research workflow graph with nodes and edges."""

    nodes: dict[str, GraphNode] = Field(default_factory=dict, description="All nodes in the graph")
    edges: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="Edges by source node ID"
    )
    entry_node: str = Field(description="Starting node ID")
    exit_nodes: list[str] = Field(default_factory=list, description="Terminal node IDs")

    model_config = {"arbitrary_types_allowed": True}

    def add_node(self, node: GraphNode) -> None:
        """Add a node to the graph.

        Args:
            node: The node to add

        Raises:
            ValueError: If node ID already exists
        """
        if node.node_id in self.nodes:
            raise ValueError(f"Node {node.node_id} already exists in graph")
        self.nodes[node.node_id] = node
        logger.debug("Node added to graph", node_id=node.node_id, type=node.node_type)

    def add_edge(self, edge: GraphEdge) -> None:
        """Add an edge to the graph.

        Args:
            edge: The edge to add

        Raises:
            ValueError: If source or target node doesn't exist
        """
        if edge.from_node not in self.nodes:
            raise ValueError(f"Source node {edge.from_node} not found in graph")
        if edge.to_node not in self.nodes:
            raise ValueError(f"Target node {edge.to_node} not found in graph")

        if edge.from_node not in self.edges:
            self.edges[edge.from_node] = []
        self.edges[edge.from_node].append(edge)
        logger.debug(
            "Edge added to graph",
            from_node=edge.from_node,
            to_node=edge.to_node,
        )

    def get_node(self, node_id: str) -> GraphNode | None:
        """Get a node by ID.

        Args:
            node_id: The node ID

        Returns:
            The node, or None if not found
        """
        return self.nodes.get(node_id)

    def get_next_nodes(self, node_id: str, context: Any = None) -> list[tuple[str, GraphEdge]]:
        """Get all possible next nodes from a given node.

        Args:
            node_id: The current node ID
            context: Optional context for evaluating conditions

        Returns:
            List of (node_id, edge) tuples for valid next nodes
        """
        if node_id not in self.edges:
            return []

        next_nodes = []
        for edge in self.edges[node_id]:
            # Evaluate condition if present
            if edge.condition is None or edge.condition(context):
                next_nodes.append((edge.to_node, edge))

        return next_nodes

    def validate_structure(self) -> list[str]:
        """Validate the graph structure.

        Returns:
            List of validation error messages (empty if valid)
        """
        errors = []

        # Check entry node exists
        if self.entry_node not in self.nodes:
            errors.append(f"Entry node {self.entry_node} not found in graph")

        # Check exit nodes exist and at least one is defined
        if not self.exit_nodes:
            errors.append("At least one exit node must be defined")
        for exit_node in self.exit_nodes:
            if exit_node not in self.nodes:
                errors.append(f"Exit node {exit_node} not found in graph")

        # Check all edges reference valid nodes
        for from_node, edge_list in self.edges.items():
            if from_node not in self.nodes:
                errors.append(f"Edge source node {from_node} not found")
            for edge in edge_list:
                if edge.to_node not in self.nodes:
                    errors.append(f"Edge target node {edge.to_node} not found")

        # Check all nodes are reachable from entry node (basic check)
        if self.entry_node in self.nodes:
            reachable = {self.entry_node}
            queue = [self.entry_node]
            while queue:
                current = queue.pop(0)
                for next_node, _ in self.get_next_nodes(current):
                    if next_node not in reachable:
                        reachable.add(next_node)
                        queue.append(next_node)

            unreachable = set(self.nodes.keys()) - reachable
            if unreachable:
                errors.append(f"Unreachable nodes from entry node: {', '.join(unreachable)}")

        return errors


# ============================================================================
# Graph Builder Class
# ============================================================================


class GraphBuilder:
    """Builder for constructing research workflow graphs."""

    def __init__(self) -> None:
        """Initialize the graph builder."""
        self.graph = ResearchGraph(entry_node="", exit_nodes=[])

    def add_agent_node(
        self,
        node_id: str,
        agent: "Agent[Any, Any]",
        description: str = "",
        input_transformer: Callable[[Any], Any] | None = None,
        output_transformer: Callable[[Any], Any] | None = None,
    ) -> "GraphBuilder":
        """Add an agent node to the graph.

        Args:
            node_id: Unique identifier for the node
            agent: Pydantic AI agent to execute
            description: Human-readable description
            input_transformer: Optional input transformation function
            output_transformer: Optional output transformation function

        Returns:
            Self for method chaining
        """
        node = AgentNode(
            node_id=node_id,
            agent=agent,
            description=description,
            input_transformer=input_transformer,
            output_transformer=output_transformer,
        )
        self.graph.add_node(node)
        return self

    def add_state_node(
        self,
        node_id: str,
        state_updater: Callable[["WorkflowState", Any], "WorkflowState"],
        description: str = "",
        state_reader: Callable[["WorkflowState"], Any] | None = None,
    ) -> "GraphBuilder":
        """Add a state node to the graph.

        Args:
            node_id: Unique identifier for the node
            state_updater: Function to update workflow state
            description: Human-readable description
            state_reader: Optional function to read state

        Returns:
            Self for method chaining
        """
        node = StateNode(
            node_id=node_id,
            state_updater=state_updater,
            description=description,
            state_reader=state_reader,
        )
        self.graph.add_node(node)
        return self

    def add_decision_node(
        self,
        node_id: str,
        decision_function: Callable[[Any], str],
        options: list[str],
        description: str = "",
    ) -> "GraphBuilder":
        """Add a decision node to the graph.

        Args:
            node_id: Unique identifier for the node
            decision_function: Function that returns next node ID
            options: List of possible next node IDs
            description: Human-readable description

        Returns:
            Self for method chaining
        """
        node = DecisionNode(
            node_id=node_id,
            decision_function=decision_function,
            options=options,
            description=description,
        )
        self.graph.add_node(node)
        return self

    def add_parallel_node(
        self,
        node_id: str,
        parallel_nodes: list[str],
        description: str = "",
        aggregator: Callable[[list[Any]], Any] | None = None,
    ) -> "GraphBuilder":
        """Add a parallel node to the graph.

        Args:
            node_id: Unique identifier for the node
            parallel_nodes: List of node IDs to run in parallel
            description: Human-readable description
            aggregator: Optional function to aggregate results

        Returns:
            Self for method chaining
        """
        node = ParallelNode(
            node_id=node_id,
            parallel_nodes=parallel_nodes,
            description=description,
            aggregator=aggregator,
        )
        self.graph.add_node(node)
        return self

    def connect_nodes(
        self,
        from_node: str,
        to_node: str,
        condition: Callable[[Any], bool] | None = None,
        condition_description: str = "",
    ) -> "GraphBuilder":
        """Connect two nodes with an edge.

        Args:
            from_node: Source node ID
            to_node: Target node ID
            condition: Optional condition function
            condition_description: Description of condition (if conditional)

        Returns:
            Self for method chaining
        """
        if condition is None:
            edge: GraphEdge = SequentialEdge(from_node=from_node, to_node=to_node)
        else:
            edge = ConditionalEdge(
                from_node=from_node,
                to_node=to_node,
                condition=condition,
                condition_description=condition_description,
            )
        self.graph.add_edge(edge)
        return self

    def set_entry_node(self, node_id: str) -> "GraphBuilder":
        """Set the entry node for the graph.

        Args:
            node_id: The entry node ID

        Returns:
            Self for method chaining
        """
        self.graph.entry_node = node_id
        return self

    def set_exit_nodes(self, node_ids: list[str]) -> "GraphBuilder":
        """Set the exit nodes for the graph.

        Args:
            node_ids: List of exit node IDs

        Returns:
            Self for method chaining
        """
        self.graph.exit_nodes = node_ids
        return self

    def build(self) -> ResearchGraph:
        """Finalize graph construction and validate.

        Returns:
            The constructed ResearchGraph

        Raises:
            ValueError: If graph validation fails
        """
        errors = self.graph.validate_structure()
        if errors:
            error_msg = "Graph validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            logger.error("Graph validation failed", errors=errors)
            raise ValueError(error_msg)

        logger.info(
            "Graph built successfully",
            nodes=len(self.graph.nodes),
            edges=sum(len(edges) for edges in self.graph.edges.values()),
            entry_node=self.graph.entry_node,
            exit_nodes=self.graph.exit_nodes,
        )
        return self.graph


# ============================================================================
# Factory Functions
# ============================================================================


def create_iterative_graph(
    knowledge_gap_agent: "Agent[Any, Any]",
    tool_selector_agent: "Agent[Any, Any]",
    thinking_agent: "Agent[Any, Any]",
    writer_agent: "Agent[Any, Any]",
) -> ResearchGraph:
    """Create a graph for iterative research flow.

    Args:
        knowledge_gap_agent: Agent for evaluating knowledge gaps
        tool_selector_agent: Agent for selecting tools
        thinking_agent: Agent for generating observations
        writer_agent: Agent for writing final report

    Returns:
        Constructed ResearchGraph for iterative research
    """
    builder = GraphBuilder()

    # Add nodes
    builder.add_agent_node("thinking", thinking_agent, "Generate observations")
    builder.add_agent_node("knowledge_gap", knowledge_gap_agent, "Evaluate knowledge gaps")
    builder.add_decision_node(
        "continue_decision",
        decision_function=lambda result: "writer"
        if getattr(result, "research_complete", False)
        else "tool_selector",
        options=["tool_selector", "writer"],
        description="Decide whether to continue research or write report",
    )
    builder.add_agent_node("tool_selector", tool_selector_agent, "Select tools to address gap")
    builder.add_state_node(
        "execute_tools",
        state_updater=lambda state,
        tasks: state,  # Placeholder - actual execution handled separately
        description="Execute selected tools",
    )
    builder.add_agent_node("writer", writer_agent, "Write final report")

    # Add edges
    builder.connect_nodes("thinking", "knowledge_gap")
    builder.connect_nodes("knowledge_gap", "continue_decision")
    builder.connect_nodes("continue_decision", "tool_selector")
    builder.connect_nodes("continue_decision", "writer")
    builder.connect_nodes("tool_selector", "execute_tools")
    builder.connect_nodes("execute_tools", "thinking")  # Loop back

    # Set entry and exit
    builder.set_entry_node("thinking")
    builder.set_exit_nodes(["writer"])

    return builder.build()


def create_deep_graph(
    planner_agent: "Agent[Any, Any]",
    knowledge_gap_agent: "Agent[Any, Any]",
    tool_selector_agent: "Agent[Any, Any]",
    thinking_agent: "Agent[Any, Any]",
    writer_agent: "Agent[Any, Any]",
    long_writer_agent: "Agent[Any, Any]",
) -> ResearchGraph:
    """Create a graph for deep research flow.

    The graph structure: planner → store_plan → parallel_loops → collect_drafts → synthesizer

    Args:
        planner_agent: Agent for creating report plan
        knowledge_gap_agent: Agent for evaluating knowledge gaps (not used directly, but needed for iterative flows)
        tool_selector_agent: Agent for selecting tools (not used directly, but needed for iterative flows)
        thinking_agent: Agent for generating observations (not used directly, but needed for iterative flows)
        writer_agent: Agent for writing section reports (not used directly, but needed for iterative flows)
        long_writer_agent: Agent for synthesizing final report

    Returns:
        Constructed ResearchGraph for deep research
    """
    from src.utils.models import ReportPlan

    builder = GraphBuilder()

    # Add nodes
    # 1. Planner agent - creates report plan
    builder.add_agent_node("planner", planner_agent, "Create report plan with sections")

    # 2. State node - store report plan in workflow state
    def store_plan(state: "WorkflowState", plan: ReportPlan) -> "WorkflowState":
        """Store report plan in state for parallel loops to access."""
        # Store plan in a custom attribute (we'll need to extend WorkflowState or use a dict)
        # For now, we'll store it in the context's node_results
        # The actual storage will happen in the graph execution
        return state

    builder.add_state_node(
        "store_plan",
        state_updater=store_plan,
        description="Store report plan in state",
    )

    # 3. Parallel node - will execute iterative research flows for each section
    # The actual execution will be handled dynamically in _execute_parallel_node()
    # We use a special node ID that the executor will recognize
    builder.add_parallel_node(
        "parallel_loops",
        parallel_nodes=[],  # Will be populated dynamically based on report plan
        description="Execute parallel iterative research loops for each section",
        aggregator=lambda results: results,  # Collect all section drafts
    )

    # 4. State node - collect section drafts into ReportDraft
    def collect_drafts(state: "WorkflowState", section_drafts: list[str]) -> "WorkflowState":
        """Collect section drafts into state for synthesizer."""
        # Store drafts in state (will be accessed by synthesizer)
        return state

    builder.add_state_node(
        "collect_drafts",
        state_updater=collect_drafts,
        description="Collect section drafts for synthesis",
    )

    # 5. Synthesizer agent - creates final report from drafts
    builder.add_agent_node(
        "synthesizer", long_writer_agent, "Synthesize final report from section drafts"
    )

    # Add edges
    builder.connect_nodes("planner", "store_plan")
    builder.connect_nodes("store_plan", "parallel_loops")
    builder.connect_nodes("parallel_loops", "collect_drafts")
    builder.connect_nodes("collect_drafts", "synthesizer")

    # Set entry and exit
    builder.set_entry_node("planner")
    builder.set_exit_nodes(["synthesizer"])

    return builder.build()


# No need to rebuild models since we're using Any types
# The models will work correctly with arbitrary_types_allowed=True

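To make the builder API concrete, a minimal sketch that assembles and validates a two-node graph. The two agents are placeholder objects, which `AgentNode` accepts because its `agent` field is typed `Any`; actual execution is handled elsewhere:

```python
from src.agent_factory.graph_builder import GraphBuilder

# Placeholder stand-ins for pydantic-ai agents; AgentNode stores `agent: Any`,
# so plain objects are enough to exercise construction and validation.
research_agent = object()
writer_agent = object()

graph = (
    GraphBuilder()
    .add_agent_node("research", research_agent, "Gather evidence")
    .add_agent_node("write", writer_agent, "Write the report")
    .connect_nodes("research", "write")
    .set_entry_node("research")
    .set_exit_nodes(["write"])
    .build()  # raises ValueError if validate_structure() reports problems
)

print(graph.get_next_nodes("research"))  # [("write", <SequentialEdge>)]
```

Each `add_*`/`connect_*`/`set_*` method returns the builder, so the whole graph can be declared in one chained expression before `build()` runs validation.
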
src/agent_factory/judges.py
CHANGED
@@ -351,6 +351,15 @@ IMPORTANT: Respond with ONLY valid JSON matching this schema:
 )
 
 
+def create_judge_handler() -> JudgeHandler:
+    """Create a judge handler based on configuration.
+
+    Returns:
+        Configured JudgeHandler instance
+    """
+    return JudgeHandler()
+
+
 class MockJudgeHandler:
     """
     Mock JudgeHandler for demo mode without LLM calls.

src/agents/input_parser.py
ADDED

@@ -0,0 +1,178 @@
+"""Input parser agent for analyzing and improving user queries.
+
+Determines research mode (iterative vs deep) and extracts key information
+from user queries to improve research quality.
+"""
+
+from typing import TYPE_CHECKING, Any, Literal
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError, JudgeError
+from src.utils.models import ParsedQuery
+
+if TYPE_CHECKING:
+    pass
+
+logger = structlog.get_logger()
+
+# System prompt for the input parser agent
+SYSTEM_PROMPT = """
+You are an expert research query analyzer. Your job is to analyze user queries and determine:
+1. Whether the query requires iterative research (single focused question) or deep research (multiple sections/topics)
+2. Improve and refine the query for better research results
+3. Extract key entities (drugs, diseases, targets, companies, etc.)
+4. Extract specific research questions
+
+Guidelines for determining research mode:
+- **Iterative mode**: Single focused question, straightforward research goal, can be answered with a focused search loop
+  Examples: "What is the mechanism of metformin?", "Find clinical trials for drug X"
+
+- **Deep mode**: Complex query requiring multiple sections, comprehensive report, multiple related topics
+  Examples: "Write a comprehensive report on diabetes treatment", "Analyze the market for quantum computing"
+  Indicators: words like "comprehensive", "report", "sections", "analyze", "market analysis", "overview"
+
+Your output must be valid JSON matching the ParsedQuery schema. Always provide:
+- original_query: The exact input query
+- improved_query: A refined, clearer version of the query
+- research_mode: Either "iterative" or "deep"
+- key_entities: List of important entities (drugs, diseases, companies, etc.)
+- research_questions: List of specific questions to answer
+
+Only output JSON. Do not output anything else.
+"""
+
+
+class InputParserAgent:
+    """
+    Input parser agent that analyzes queries and determines research mode.
+
+    Uses Pydantic AI to generate structured ParsedQuery output with research
+    mode detection, query improvement, and entity extraction.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the input parser agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent
+        self.agent = Agent(
+            model=self.model,
+            output_type=ParsedQuery,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def parse(self, query: str) -> ParsedQuery:
+        """
+        Parse and analyze a user query.
+
+        Args:
+            query: The user's research query
+
+        Returns:
+            ParsedQuery with research mode, improved query, entities, and questions
+
+        Raises:
+            JudgeError: If parsing fails after retries
+            ConfigurationError: If agent configuration is invalid
+        """
+        self.logger.info("Parsing user query", query=query[:100])
+
+        user_message = f"QUERY: {query}"
+
+        try:
+            # Run the agent
+            result = await self.agent.run(user_message)
+            parsed_query = result.output
+
+            # Validate parsed query
+            if not parsed_query.original_query:
+                self.logger.warning("Parsed query missing original_query", query=query[:100])
+                raise JudgeError("Parsed query must have original_query")
+
+            if not parsed_query.improved_query:
+                self.logger.warning("Parsed query missing improved_query", query=query[:100])
+                # Use original as fallback
+                parsed_query = ParsedQuery(
+                    original_query=parsed_query.original_query,
+                    improved_query=parsed_query.original_query,
+                    research_mode=parsed_query.research_mode,
+                    key_entities=parsed_query.key_entities,
+                    research_questions=parsed_query.research_questions,
+                )
+
+            self.logger.info(
+                "Query parsed successfully",
+                mode=parsed_query.research_mode,
+                entities=len(parsed_query.key_entities),
+                questions=len(parsed_query.research_questions),
+            )
+
+            return parsed_query
+
+        except Exception as e:
+            self.logger.error("Query parsing failed", error=str(e), query=query[:100])
+
+            # Fallback: return basic parsed query with heuristic mode detection
+            if isinstance(e, JudgeError | ConfigurationError):
+                raise
+
+            # Heuristic fallback
+            query_lower = query.lower()
+            research_mode: Literal["iterative", "deep"] = "iterative"
+            if any(
+                keyword in query_lower
+                for keyword in [
+                    "comprehensive",
+                    "report",
+                    "sections",
+                    "analyze",
+                    "analysis",
+                    "overview",
+                    "market",
+                ]
+            ):
+                research_mode = "deep"
+
+            return ParsedQuery(
+                original_query=query,
+                improved_query=query,
+                research_mode=research_mode,
+                key_entities=[],
+                research_questions=[],
+            )
+
+
+def create_input_parser_agent(model: Any | None = None) -> InputParserAgent:
+    """
+    Factory function to create an input parser agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured InputParserAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        # Get model from settings if not provided
+        if model is None:
+            model = get_model()
+
+        # Create and return input parser agent
+        return InputParserAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create input parser agent", error=str(e))
+        raise ConfigurationError(f"Failed to create input parser agent: {e}") from e
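Usage sketch (not part of the commit), assuming model credentials are configured so get_model() succeeds:

import asyncio

from src.agents.input_parser import create_input_parser_agent

async def main() -> None:
    parser = create_input_parser_agent()
    parsed = await parser.parse(
        "Write a comprehensive report on drug repurposing for long COVID fatigue"
    )
    # "comprehensive" and "report" are deep-mode indicators, so research_mode
    # should come back as "deep" for this query.
    print(parsed.research_mode, parsed.key_entities, parsed.research_questions)

asyncio.run(main())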
src/agents/judge_agent.py
CHANGED

@@ -12,7 +12,7 @@ from agent_framework import (
     Role,
 )
 
-from src.orchestrator import JudgeHandlerProtocol
+from src.legacy_orchestrator import JudgeHandlerProtocol
 from src.utils.models import Evidence, JudgeAssessment
 
 
src/agents/knowledge_gap.py
ADDED

@@ -0,0 +1,156 @@
+"""Knowledge gap agent for evaluating research completeness.
+
+Converts the folder/knowledge_gap_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import KnowledgeGapOutput
+
+logger = structlog.get_logger()
+
+
+# System prompt for the knowledge gap agent
+SYSTEM_PROMPT = f"""
+You are a Research State Evaluator. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
+Your job is to critically analyze the current state of a research report,
+identify what knowledge gaps still exist and determine the best next step to take.
+
+You will be given:
+1. The original user query and any relevant background context to the query
+2. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process
+
+Your task is to:
+1. Carefully review the findings and thoughts, particularly from the latest iteration, and assess their completeness in answering the original query
+2. Determine if the findings are sufficiently complete to end the research loop
+3. If not, identify up to 3 knowledge gaps that need to be addressed in sequence in order to continue with research - these should be relevant to the original query
+
+Be specific in the gaps you identify and include relevant information as this will be passed onto another agent to process without additional context.
+
+Only output JSON. Follow the JSON schema for KnowledgeGapOutput. Do not output anything else.
+"""
+
+
+class KnowledgeGapAgent:
+    """
+    Agent that evaluates research state and identifies knowledge gaps.
+
+    Uses Pydantic AI to generate structured KnowledgeGapOutput indicating
+    whether research is complete and what gaps remain.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the knowledge gap agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent
+        self.agent = Agent(
+            model=self.model,
+            output_type=KnowledgeGapOutput,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def evaluate(
+        self,
+        query: str,
+        background_context: str = "",
+        conversation_history: str = "",
+        iteration: int = 0,
+        time_elapsed_minutes: float = 0.0,
+        max_time_minutes: int = 10,
+    ) -> KnowledgeGapOutput:
+        """
+        Evaluate research state and identify knowledge gaps.
+
+        Args:
+            query: The original research query
+            background_context: Optional background context
+            conversation_history: History of actions, findings, and thoughts
+            iteration: Current iteration number
+            time_elapsed_minutes: Time elapsed so far
+            max_time_minutes: Maximum time allowed
+
+        Returns:
+            KnowledgeGapOutput with research completeness and outstanding gaps
+
+        Raises:
+            JudgeError: If evaluation fails after retries
+        """
+        self.logger.info(
+            "Evaluating knowledge gaps",
+            query=query[:100],
+            iteration=iteration,
+        )
+
+        background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
+
+        user_message = f"""
+Current Iteration Number: {iteration}
+Time Elapsed: {time_elapsed_minutes:.2f} minutes of maximum {max_time_minutes} minutes
+
+ORIGINAL QUERY:
+{query}
+
+{background}
+
+HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
+{conversation_history or "No previous actions, findings or thoughts available."}
+"""
+
+        try:
+            # Run the agent
+            result = await self.agent.run(user_message)
+            evaluation = result.output
+
+            self.logger.info(
+                "Knowledge gap evaluation complete",
+                research_complete=evaluation.research_complete,
+                gaps_count=len(evaluation.outstanding_gaps),
+            )
+
+            return evaluation
+
+        except Exception as e:
+            self.logger.error("Knowledge gap evaluation failed", error=str(e))
+            # Return fallback: research not complete, suggest continuing
+            return KnowledgeGapOutput(
+                research_complete=False,
+                outstanding_gaps=[f"Continue research on: {query}"],
+            )
+
+
+def create_knowledge_gap_agent(model: Any | None = None) -> KnowledgeGapAgent:
+    """
+    Factory function to create a knowledge gap agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured KnowledgeGapAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        if model is None:
+            model = get_model()
+
+        return KnowledgeGapAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create knowledge gap agent", error=str(e))
+        raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e
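Usage sketch (not part of the commit) for the gap evaluator inside a research loop; the history string here is illustrative:

import asyncio

from src.agents.knowledge_gap import create_knowledge_gap_agent

async def main() -> None:
    agent = create_knowledge_gap_agent()
    evaluation = await agent.evaluate(
        query="What is the mechanism of metformin?",
        conversation_history="Iteration 1: searched PubMed, found AMPK activation evidence.",
        iteration=2,
        time_elapsed_minutes=3.5,
        max_time_minutes=10,
    )
    if not evaluation.research_complete:
        for gap in evaluation.outstanding_gaps:
            print("Next gap:", gap)

asyncio.run(main())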
src/agents/long_writer.py
ADDED

@@ -0,0 +1,431 @@
+"""Long writer agent for iteratively writing report sections.
+
+Converts the folder/long_writer_agent.py implementation to use Pydantic AI.
+"""
+
+import re
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic import BaseModel, Field
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import ReportDraft
+
+logger = structlog.get_logger()
+
+
+# LongWriterOutput model for structured output
+class LongWriterOutput(BaseModel):
+    """Output from the long writer agent for a single section."""
+
+    next_section_markdown: str = Field(
+        description="The final draft of the next section in markdown format"
+    )
+    references: list[str] = Field(
+        description="A list of URLs and their corresponding reference numbers for the section"
+    )
+
+    model_config = {"frozen": True}
+
+
+# System prompt for the long writer agent
+SYSTEM_PROMPT = f"""
+You are an expert report writer tasked with iteratively writing each section of a report.
+Today's date is {datetime.now().strftime('%Y-%m-%d')}.
+You will be provided with:
+1. The original research query
+2. A final draft of the report containing the table of contents and all sections written up until this point (in the first iteration there will be no sections written yet)
+3. A first draft of the next section of the report to be written
+
+OBJECTIVE:
+1. Write a final draft of the next section of the report with numbered citations in square brackets in the body of the report
+2. Produce a list of references to be appended to the end of the report
+
+CITATIONS/REFERENCES:
+The citations should be in numerical order, written in numbered square brackets in the body of the report.
+Separately, a list of all URLs and their corresponding reference numbers will be included at the end of the report.
+Follow the example below for formatting.
+
+LongWriterOutput(
+    next_section_markdown="The company specializes in IT consulting [1]. It operates in the software services market which is expected to grow at 10% per year [2].",
+    references=["[1] https://example.com/first-source-url", "[2] https://example.com/second-source-url"]
+)
+
+GUIDELINES:
+- You can reformat and reorganize the flow of the content and headings within a section to flow logically, but DO NOT remove details that were included in the first draft
+- Only remove text from the first draft if it is already mentioned earlier in the report, or if it should be covered in a later section per the table of contents
+- Ensure the heading for the section matches the table of contents
+- Format the final output and references section as markdown
+- Do not include a title for the reference section, just a list of numbered references
+
+Only output JSON. Follow the JSON schema for LongWriterOutput. Do not output anything else.
+"""
+
+
+class LongWriterAgent:
+    """
+    Agent that iteratively writes report sections with proper citations.
+
+    Uses Pydantic AI to generate structured LongWriterOutput for each section.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the long writer agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent
+        self.agent = Agent(
+            model=self.model,
+            output_type=LongWriterOutput,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def write_next_section(
+        self,
+        original_query: str,
+        report_draft: str,
+        next_section_title: str,
+        next_section_draft: str,
+    ) -> LongWriterOutput:
+        """
+        Write the next section of the report.
+
+        Args:
+            original_query: The original research query
+            report_draft: Current report draft (all sections written so far)
+            next_section_title: Title of the section to write
+            next_section_draft: Draft content for the next section
+
+        Returns:
+            LongWriterOutput with formatted section and references
+
+        Raises:
+            ConfigurationError: If writing fails
+        """
+        # Input validation
+        if not original_query or not original_query.strip():
+            self.logger.warning("Empty query provided, using default")
+            original_query = "Research query"
+
+        if not next_section_title or not next_section_title.strip():
+            self.logger.warning("Empty section title provided, using default")
+            next_section_title = "Section"
+
+        if next_section_draft is None:
+            next_section_draft = ""
+
+        if report_draft is None:
+            report_draft = ""
+
+        # Truncate very long inputs
+        max_draft_length = 30000
+        if len(report_draft) > max_draft_length:
+            self.logger.warning(
+                "Report draft too long, truncating",
+                original_length=len(report_draft),
+            )
+            report_draft = report_draft[:max_draft_length] + "\n\n[Content truncated]"
+
+        if len(next_section_draft) > max_draft_length:
+            self.logger.warning(
+                "Section draft too long, truncating",
+                original_length=len(next_section_draft),
+            )
+            next_section_draft = next_section_draft[:max_draft_length] + "\n\n[Content truncated]"
+
+        self.logger.info(
+            "Writing next section",
+            section_title=next_section_title,
+            query=original_query[:100],
+        )
+
+        user_message = f"""
+<ORIGINAL QUERY>
+{original_query}
+</ORIGINAL QUERY>
+
+<CURRENT REPORT DRAFT>
+{report_draft or "No draft yet"}
+</CURRENT REPORT DRAFT>
+
+<TITLE OF NEXT SECTION TO WRITE>
+{next_section_title}
+</TITLE OF NEXT SECTION TO WRITE>
+
+<DRAFT OF NEXT SECTION>
+{next_section_draft}
+</DRAFT OF NEXT SECTION>
+"""
+
+        # Retry logic for transient failures
+        max_retries = 3
+        last_exception: Exception | None = None
+
+        for attempt in range(max_retries):
+            try:
+                # Run the agent
+                result = await self.agent.run(user_message)
+                output = result.output
+
+                # Validate output
+                if not output or not isinstance(output, LongWriterOutput):
+                    raise ValueError("Invalid output format")
+
+                if not output.next_section_markdown or not output.next_section_markdown.strip():
+                    self.logger.warning("Empty section generated, using fallback")
+                    raise ValueError("Empty section generated")
+
+                self.logger.info(
+                    "Section written",
+                    section_title=next_section_title,
+                    references_count=len(output.references),
+                    attempt=attempt + 1,
+                )
+
+                return output
+
+            except (TimeoutError, ConnectionError) as e:
+                # Transient errors - retry
+                last_exception = e
+                if attempt < max_retries - 1:
+                    self.logger.warning(
+                        "Transient error, retrying",
+                        error=str(e),
+                        attempt=attempt + 1,
+                        max_retries=max_retries,
+                    )
+                    continue
+                else:
+                    self.logger.error("Max retries exceeded for transient error", error=str(e))
+                    break
+
+            except Exception as e:
+                # Non-transient errors - don't retry
+                last_exception = e
+                self.logger.error(
+                    "Section writing failed",
+                    error=str(e),
+                    error_type=type(e).__name__,
+                )
+                break
+
+        # Return fallback section if all attempts failed
+        self.logger.error(
+            "Section writing failed after all attempts",
+            error=str(last_exception) if last_exception else "Unknown error",
+        )
+        return LongWriterOutput(
+            next_section_markdown=f"## {next_section_title}\n\n{next_section_draft}",
+            references=[],
+        )
+
+    async def write_report(
+        self,
+        original_query: str,
+        report_title: str,
+        report_draft: ReportDraft,
+    ) -> str:
+        """
+        Write the final report by iteratively writing each section.
+
+        Args:
+            original_query: The original research query
+            report_title: Title of the report
+            report_draft: ReportDraft with all sections
+
+        Returns:
+            Complete markdown report string
+
+        Raises:
+            ConfigurationError: If writing fails
+        """
+        # Input validation
+        if not original_query or not original_query.strip():
+            self.logger.warning("Empty query provided, using default")
+            original_query = "Research query"
+
+        if not report_title or not report_title.strip():
+            self.logger.warning("Empty report title provided, using default")
+            report_title = "Research Report"
+
+        if not report_draft or not report_draft.sections:
+            self.logger.warning("Empty report draft provided, returning minimal report")
+            return f"# {report_title}\n\n## Query\n{original_query}\n\n*No sections available.*"
+
+        self.logger.info(
+            "Writing full report",
+            report_title=report_title,
+            sections_count=len(report_draft.sections),
+        )
+
+        # Initialize the final draft with title and table of contents
+        final_draft = (
+            f"# {report_title}\n\n## Table of Contents\n\n"
+            + "\n".join(
+                [
+                    f"{i+1}. {section.section_title}"
+                    for i, section in enumerate(report_draft.sections)
+                ]
+            )
+            + "\n\n"
+        )
+        all_references: list[str] = []
+
+        for section in report_draft.sections:
+            # Write each section
+            next_section_output = await self.write_next_section(
+                original_query,
+                final_draft,
+                section.section_title,
+                section.section_content,
+            )
+
+            # Reformat references and update section markdown
+            section_markdown, all_references = self._reformat_references(
+                next_section_output.next_section_markdown,
+                next_section_output.references,
+                all_references,
+            )
+
+            # Reformat section headings
+            section_markdown = self._reformat_section_headings(section_markdown)
+
+            # Add to final draft
+            final_draft += section_markdown + "\n\n"
+
+        # Add final references
+        final_draft += "## References:\n\n" + " \n".join(all_references)
+
+        self.logger.info("Full report written", length=len(final_draft))
+
+        return final_draft
+
+    def _reformat_references(
+        self,
+        section_markdown: str,
+        section_references: list[str],
+        all_references: list[str],
+    ) -> tuple[str, list[str]]:
+        """
+        Reformat references: re-number, de-duplicate, and update markdown.
+
+        Args:
+            section_markdown: Markdown content with inline references [1], [2]
+            section_references: List of references for this section
+            all_references: Accumulated references from previous sections
+
+        Returns:
+            Tuple of (updated markdown, updated all_references)
+        """
+
+        # Convert reference lists to maps (URL -> ref_num)
+        def convert_ref_list_to_map(ref_list: list[str]) -> dict[str, int]:
+            ref_map: dict[str, int] = {}
+            for ref in ref_list:
+                try:
+                    # Parse "[1] https://example.com" format
+                    parts = ref.split("]", 1)
+                    if len(parts) == 2:
+                        ref_num = int(parts[0].strip("["))
+                        url = parts[1].strip()
+                        ref_map[url] = ref_num
+                except (ValueError, IndexError):
+                    logger.warning("Invalid reference format", ref=ref)
+                    continue
+            return ref_map
+
+        section_ref_map = convert_ref_list_to_map(section_references)
+        report_ref_map = convert_ref_list_to_map(all_references)
+        section_to_report_ref_map: dict[int, int] = {}
+
+        report_urls = set(report_ref_map.keys())
+        ref_count = max(report_ref_map.values() or [0])
+
+        # Map section references to report references
+        for url, section_ref_num in section_ref_map.items():
+            if url in report_urls:
+                # URL already exists - reuse its reference number
+                section_to_report_ref_map[section_ref_num] = report_ref_map[url]
+            else:
+                # New URL - assign next reference number
+                ref_count += 1
+                section_to_report_ref_map[section_ref_num] = ref_count
+                all_references.append(f"[{ref_count}] {url}")
+
+        # Replace reference numbers in markdown
+        def replace_reference(match: re.Match[str]) -> str:
+            ref_num = int(match.group(1))
+            mapped_ref_num = section_to_report_ref_map.get(ref_num)
+            if mapped_ref_num:
+                return f"[{mapped_ref_num}]"
+            return ""
+
+        updated_markdown = re.sub(r"\[(\d+)\]", replace_reference, section_markdown)
+
+        return updated_markdown, all_references
+
+    def _reformat_section_headings(self, section_markdown: str) -> str:
+        """
+        Reformat section headings to be consistent (level-2 for main heading).
+
+        Args:
+            section_markdown: Markdown content with headings
+
+        Returns:
+            Updated markdown with adjusted heading levels
+        """
+        if not section_markdown.strip():
+            return section_markdown
+
+        # Find first heading level
+        first_heading_match = re.search(r"^(#+)\s", section_markdown, re.MULTILINE)
+        if not first_heading_match:
+            return section_markdown
+
+        # Calculate level adjustment needed (target is level 2)
+        first_heading_level = len(first_heading_match.group(1))
+        level_adjustment = 2 - first_heading_level
+
+        def adjust_heading_level(match: re.Match[str]) -> str:
+            hashes = match.group(1)
+            content = match.group(2)
+            new_level = max(2, len(hashes) + level_adjustment)
+            return "#" * new_level + " " + content
+
+        # Apply heading adjustment
+        return re.sub(r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE)
+
+
+def create_long_writer_agent(model: Any | None = None) -> LongWriterAgent:
+    """
+    Factory function to create a long writer agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured LongWriterAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        if model is None:
+            model = get_model()
+
+        return LongWriterAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create long writer agent", error=str(e))
+        raise ConfigurationError(f"Failed to create long writer agent: {e}") from e
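Usage sketch (not part of the commit). ReportDraft's section items expose section_title and section_content as used above; model_validate is used here to avoid assuming the section class name. Note the design choice in _reformat_references: URLs are de-duplicated across sections, so a source cited in two sections keeps a single global reference number.

import asyncio

from src.agents.long_writer import create_long_writer_agent
from src.utils.models import ReportDraft

async def main() -> None:
    writer = create_long_writer_agent()
    draft = ReportDraft.model_validate(
        {
            "sections": [
                {"section_title": "Background", "section_content": "Metformin activates AMPK."},
                {"section_title": "Clinical Evidence", "section_content": "Trial results to date."},
            ]
        }
    )
    report = await writer.write_report(
        original_query="Metformin repurposing overview",
        report_title="Metformin Repurposing",
        report_draft=draft,
    )
    print(report)  # title, table of contents, sections, then a global reference list

asyncio.run(main())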
src/agents/proofreader.py
ADDED

@@ -0,0 +1,205 @@
+"""Proofreader agent for finalizing report drafts.
+
+Converts the folder/proofreader_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import ReportDraft
+
+logger = structlog.get_logger()
+
+
+# System prompt for the proofreader agent
+SYSTEM_PROMPT = f"""
+You are a research expert who proofreads and edits research reports.
+Today's date is {datetime.now().strftime("%Y-%m-%d")}.
+
+You are given:
+1. The original query topic for the report
+2. A first draft of the report in ReportDraft format containing each section in sequence
+
+Your task is to:
+1. **Combine sections:** Concatenate the sections into a single string
+2. **Add section titles:** Add the section titles to the beginning of each section in markdown format, as well as a main title for the report
+3. **De-duplicate:** Remove duplicate content across sections to avoid repetition
+4. **Remove irrelevant sections:** If any sections or sub-sections are completely irrelevant to the query, remove them
+5. **Refine wording:** Edit the wording of the report to be polished, concise and punchy, but **without eliminating any detail** or large chunks of text
+6. **Add a summary:** Add a short report summary / outline to the beginning of the report to provide an overview of the sections and what is discussed
+7. **Preserve sources:** Preserve all sources / references - move the long list of references to the end of the report
+8. **Update reference numbers:** Continue to include reference numbers in square brackets ([1], [2], [3], etc.) in the main body of the report, but update the numbering to match the new order of references at the end of the report
+9. **Output final report:** Output the final report in markdown format (do not wrap it in a code block)
+
+Guidelines:
+- Do not add any new facts or data to the report
+- Do not remove any content from the report unless it is very clearly wrong, contradictory or irrelevant
+- Remove or reformat any redundant or excessive headings, and ensure that the final nesting of heading levels is correct
+- Ensure that the final report flows well and has a logical structure
+- Include all sources and references that are present in the final report
+"""
+
+
+class ProofreaderAgent:
+    """
+    Agent that proofreads and finalizes report drafts.
+
+    Uses Pydantic AI to generate polished markdown reports from draft sections.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the proofreader agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent (no structured output - returns markdown text)
+        self.agent = Agent(
+            model=self.model,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def proofread(
+        self,
+        query: str,
+        report_draft: ReportDraft,
+    ) -> str:
+        """
+        Proofread and finalize a report draft.
+
+        Args:
+            query: The original research query
+            report_draft: ReportDraft with all sections
+
+        Returns:
+            Final polished markdown report string
+
+        Raises:
+            ConfigurationError: If proofreading fails
+        """
+        # Input validation
+        if not query or not query.strip():
+            self.logger.warning("Empty query provided, using default")
+            query = "Research query"
+
+        if not report_draft or not report_draft.sections:
+            self.logger.warning("Empty report draft provided, returning minimal report")
+            return f"# Research Report\n\n## Query\n{query}\n\n*No sections available.*"
+
+        # Validate section structure
+        valid_sections = []
+        for section in report_draft.sections:
+            if section.section_title and section.section_title.strip():
+                valid_sections.append(section)
+            else:
+                self.logger.warning("Skipping section with empty title")
+
+        if not valid_sections:
+            self.logger.warning("No valid sections in draft, returning minimal report")
+            return f"# Research Report\n\n## Query\n{query}\n\n*No valid sections available.*"
+
+        self.logger.info(
+            "Proofreading report",
+            query=query[:100],
+            sections_count=len(valid_sections),
+        )
+
+        # Create validated draft
+        validated_draft = ReportDraft(sections=valid_sections)
+
+        user_message = f"""
+QUERY:
+{query}
+
+REPORT DRAFT:
+{validated_draft.model_dump_json()}
+"""
+
+        # Retry logic for transient failures
+        max_retries = 3
+        last_exception: Exception | None = None
+
+        for attempt in range(max_retries):
+            try:
+                # Run the agent
+                result = await self.agent.run(user_message)
+                final_report = result.output
+
+                # Validate output
+                if not final_report or not final_report.strip():
+                    self.logger.warning("Empty report generated, using fallback")
+                    raise ValueError("Empty report generated")
+
+                self.logger.info("Report proofread", length=len(final_report), attempt=attempt + 1)
+
+                return final_report
+
+            except (TimeoutError, ConnectionError) as e:
+                # Transient errors - retry
+                last_exception = e
+                if attempt < max_retries - 1:
+                    self.logger.warning(
+                        "Transient error, retrying",
+                        error=str(e),
+                        attempt=attempt + 1,
+                        max_retries=max_retries,
+                    )
+                    continue
+                else:
+                    self.logger.error("Max retries exceeded for transient error", error=str(e))
+                    break
+
+            except Exception as e:
+                # Non-transient errors - don't retry
+                last_exception = e
+                self.logger.error(
+                    "Proofreading failed",
+                    error=str(e),
+                    error_type=type(e).__name__,
+                )
+                break
+
+        # Return fallback: combine sections manually
+        self.logger.error(
+            "Proofreading failed after all attempts",
+            error=str(last_exception) if last_exception else "Unknown error",
+        )
+        sections = [
+            f"## {section.section_title}\n\n{section.section_content or 'Content unavailable.'}"
+            for section in valid_sections
+        ]
+        return f"# Research Report\n\n## Query\n{query}\n\n" + "\n\n".join(sections)
+
+
+def create_proofreader_agent(model: Any | None = None) -> ProofreaderAgent:
+    """
+    Factory function to create a proofreader agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured ProofreaderAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        if model is None:
+            model = get_model()
+
+        return ProofreaderAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create proofreader agent", error=str(e))
+        raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e
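Usage sketch (not part of the commit), mirroring the long writer example; proofread() returns the final markdown string rather than a structured object:

import asyncio

from src.agents.proofreader import create_proofreader_agent
from src.utils.models import ReportDraft

async def main() -> None:
    proofreader = create_proofreader_agent()
    draft = ReportDraft.model_validate(
        {"sections": [{"section_title": "Findings", "section_content": "Draft findings text."}]}
    )
    final_report = await proofreader.proofread("Metformin repurposing overview", draft)
    print(final_report)

asyncio.run(main())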
src/agents/search_agent.py
CHANGED

@@ -10,7 +10,7 @@ from agent_framework import (
     Role,
 )
 
-from src.orchestrator import SearchHandlerProtocol
+from src.legacy_orchestrator import SearchHandlerProtocol
 from src.utils.models import Citation, Evidence, SearchResult
 
 if TYPE_CHECKING:
src/agents/state.py
CHANGED

@@ -1,9 +1,11 @@
 """Thread-safe state management for Magentic agents.
 
-
-
+DEPRECATED: This module is deprecated. Use src.middleware.state_machine instead.
+
+This file is kept for backward compatibility and will be removed in a future version.
 """
 
+import warnings
 from contextvars import ContextVar
 from typing import TYPE_CHECKING, Any

@@ -15,8 +17,20 @@ if TYPE_CHECKING:
     from src.services.embeddings import EmbeddingService
 
 
+def _deprecation_warning() -> None:
+    """Emit deprecation warning for this module."""
+    warnings.warn(
+        "src.agents.state is deprecated. Use src.middleware.state_machine instead.",
+        DeprecationWarning,
+        stacklevel=3,
+    )
+
+
 class MagenticState(BaseModel):
-    """Mutable state for a Magentic workflow session."""
+    """Mutable state for a Magentic workflow session.
+
+    DEPRECATED: Use WorkflowState from src.middleware.state_machine instead.
+    """
 
     evidence: list[Evidence] = Field(default_factory=list)
     # Type as Any to avoid circular imports/runtime resolution issues

@@ -75,14 +89,22 @@ _magentic_state_var: ContextVar[MagenticState | None] = ContextVar("magentic_sta
 
 
 def init_magentic_state(embedding_service: "EmbeddingService | None" = None) -> MagenticState:
-    """Initialize a new state for the current context."""
+    """Initialize a new state for the current context.
+
+    DEPRECATED: Use init_workflow_state from src.middleware.state_machine instead.
+    """
+    _deprecation_warning()
     state = MagenticState(embedding_service=embedding_service)
     _magentic_state_var.set(state)
     return state
 
 
 def get_magentic_state() -> MagenticState:
-    """Get the current state. Raises RuntimeError if not initialized."""
+    """Get the current state. Raises RuntimeError if not initialized.
+
+    DEPRECATED: Use get_workflow_state from src.middleware.state_machine instead.
+    """
+    _deprecation_warning()
    state = _magentic_state_var.get()
    if state is None:
        # Auto-initialize if missing (e.g. during tests or simple scripts)
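Migration sketch (not part of the commit): the shim keeps working but now emits a DeprecationWarning on use; the replacement entry points named in the docstrings live in src.middleware.state_machine.

import warnings

from src.agents.state import init_magentic_state  # deprecated path, still functional

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    state = init_magentic_state()
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New code should import the replacements instead:
# from src.middleware.state_machine import init_workflow_state, get_workflow_state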
src/agents/thinking.py
ADDED

@@ -0,0 +1,148 @@
+"""Thinking agent for generating observations and reflections.
+
+Converts the folder/thinking_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError
+
+logger = structlog.get_logger()
+
+
+# System prompt for the thinking agent
+SYSTEM_PROMPT = f"""
+You are a research expert who is managing a research process in iterations. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
+
+You are given:
+1. The original research query along with some supporting background context
+2. A history of the tasks, actions, findings and thoughts you've made up until this point in the research process (on iteration 1 you will be at the start of the research process, so this will be empty)
+
+Your objective is to reflect on the research process so far and share your latest thoughts.
+
+Specifically, your thoughts should include reflections on questions such as:
+- What have you learned from the last iteration?
+- What new areas would you like to explore next, or existing topics you'd like to go deeper into?
+- Were you able to retrieve the information you were looking for in the last iteration?
+- If not, should we change our approach or move to the next topic?
+- Is there any info that is contradictory or conflicting?
+
+Guidelines:
+- Share your stream of consciousness on the above questions as raw text
+- Keep your response concise and informal
+- Focus most of your thoughts on the most recent iteration and how that influences this next iteration
+- Our aim is to do very deep and thorough research - bear this in mind when reflecting on the research process
+- DO NOT produce a draft of the final report. This is not your job.
+- If this is the first iteration (i.e. no data from prior iterations), provide thoughts on what info we need to gather in the first iteration to get started
+"""
+
+
+class ThinkingAgent:
+    """
+    Agent that generates observations and reflections on the research process.
+
+    Uses Pydantic AI to generate unstructured text observations about
+    the current state of research and next steps.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the thinking agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent (no structured output - returns text)
+        self.agent = Agent(
+            model=self.model,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def generate_observations(
+        self,
+        query: str,
+        background_context: str = "",
+        conversation_history: str = "",
+        iteration: int = 1,
+    ) -> str:
+        """
+        Generate observations about the research process.
+
+        Args:
+            query: The original research query
+            background_context: Optional background context
+            conversation_history: History of actions, findings, and thoughts
+            iteration: Current iteration number
+
+        Returns:
+            String containing observations and reflections
+
+        Raises:
+            ConfigurationError: If generation fails
+        """
+        self.logger.info(
+            "Generating observations",
+            query=query[:100],
+            iteration=iteration,
+        )
+
+        background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
+
+        user_message = f"""
+You are starting iteration {iteration} of your research process.
+
+ORIGINAL QUERY:
+{query}
+
+{background}
+
+HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
+{conversation_history or "No previous actions, findings or thoughts available."}
+"""
+
+        try:
+            # Run the agent
+            result = await self.agent.run(user_message)
+            observations = result.output
+
+            self.logger.info("Observations generated", length=len(observations))
+
+            return observations
+
+        except Exception as e:
+            self.logger.error("Observation generation failed", error=str(e))
+            # Return fallback observations
+            return f"Starting iteration {iteration}. Need to gather information about: {query}"
+
+
+def create_thinking_agent(model: Any | None = None) -> ThinkingAgent:
+    """
+    Factory function to create a thinking agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured ThinkingAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        if model is None:
+            model = get_model()
+
+        return ThinkingAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create thinking agent", error=str(e))
+        raise ConfigurationError(f"Failed to create thinking agent: {e}") from e
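Usage sketch (not part of the commit) for the reflection step at the top of each research iteration; the history string is illustrative:

import asyncio

from src.agents.thinking import create_thinking_agent

async def main() -> None:
    thinker = create_thinking_agent()
    thoughts = await thinker.generate_observations(
        query="What existing drugs might help treat long COVID fatigue?",
        conversation_history="Iteration 1: found observational data on CoQ10.",
        iteration=2,
    )
    print(thoughts)  # free-form text, not structured output

asyncio.run(main())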
src/agents/tool_selector.py
ADDED
@@ -0,0 +1,168 @@
+"""Tool selector agent for choosing which tools to use for knowledge gaps.
+
+Converts the folder/tool_selector_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import AgentSelectionPlan
+
+logger = structlog.get_logger()
+
+
+# System prompt for the tool selector agent
+SYSTEM_PROMPT = f"""
+You are a Tool Selector responsible for determining which specialized agents should address a knowledge gap in a research project.
+Today's date is {datetime.now().strftime("%Y-%m-%d")}.
+
+You will be given:
+1. The original user query
+2. A knowledge gap identified in the research
+3. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process
+
+Your task is to decide:
+1. Which specialized agents are best suited to address the gap
+2. What specific queries should be given to the agents (keep this short - 3-6 words)
+
+Available specialized agents:
+- WebSearchAgent: General web search for broad topics (can be called multiple times with different queries)
+- SiteCrawlerAgent: Crawl the pages of a specific website to retrieve information about it - use this if you want to find out something about a particular company, entity or product
+- RAGAgent: Semantic search within previously collected evidence - use when you need to find information from evidence already gathered in this research session. Best for finding connections, summarizing collected evidence, or retrieving specific details from earlier findings.
+
+Guidelines:
+- Aim to call at most 3 agents at a time in your final output
+- You can list the WebSearchAgent multiple times with different queries if needed to cover the full scope of the knowledge gap
+- Be specific and concise (3-6 words) with the agent queries - they should target exactly what information is needed
+- If you know the website or domain name of an entity being researched, always include it in the query
+- Use RAGAgent when: (1) You need to search within evidence already collected, (2) You want to find connections between different findings, (3) You need to retrieve specific details from earlier research iterations
+- Use WebSearchAgent or SiteCrawlerAgent when: (1) You need fresh information from the web, (2) You're starting a new research direction, (3) You need information not yet in the collected evidence
+- If a gap doesn't clearly match any agent's capability, default to the WebSearchAgent
+- Use the history of actions / tool calls as a guide - try not to repeat yourself if an approach didn't work previously
+
+Only output JSON. Follow the JSON schema for AgentSelectionPlan. Do not output anything else.
+"""
+
+
+class ToolSelectorAgent:
+    """
+    Agent that selects appropriate tools to address knowledge gaps.
+
+    Uses Pydantic AI to generate structured AgentSelectionPlan with
+    specific tasks for web search and crawl agents.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the tool selector agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent
+        self.agent = Agent(
+            model=self.model,
+            output_type=AgentSelectionPlan,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def select_tools(
+        self,
+        gap: str,
+        query: str,
+        background_context: str = "",
+        conversation_history: str = "",
+    ) -> AgentSelectionPlan:
+        """
+        Select tools to address a knowledge gap.
+
+        Args:
+            gap: The knowledge gap to address
+            query: The original research query
+            background_context: Optional background context
+            conversation_history: History of actions, findings, and thoughts
+
+        Returns:
+            AgentSelectionPlan with tasks for selected agents. On LLM
+            failure, a fallback plan with a single WebSearchAgent task
+            is returned instead of raising.
+
+        """
+        self.logger.info("Selecting tools for gap", gap=gap[:100], query=query[:100])
+
+        background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
+
+        user_message = f"""
+ORIGINAL QUERY:
+{query}
+
+KNOWLEDGE GAP TO ADDRESS:
+{gap}
+
+{background}
+
+HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
+{conversation_history or "No previous actions, findings or thoughts available."}
+"""
+
+        try:
+            # Run the agent
+            result = await self.agent.run(user_message)
+            selection_plan = result.output
+
+            self.logger.info(
+                "Tool selection complete",
+                tasks_count=len(selection_plan.tasks),
+                agents=[task.agent for task in selection_plan.tasks],
+            )
+
+            return selection_plan
+
+        except Exception as e:
+            self.logger.error("Tool selection failed", error=str(e))
+            # Return fallback: use web search
+            from src.utils.models import AgentTask
+
+            return AgentSelectionPlan(
+                tasks=[
+                    AgentTask(
+                        gap=gap,
+                        agent="WebSearchAgent",
+                        query=gap[:50],  # Use gap as query
+                        entity_website=None,
+                    )
+                ]
+            )
+
+
+def create_tool_selector_agent(model: Any | None = None) -> ToolSelectorAgent:
+    """
+    Factory function to create a tool selector agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured ToolSelectorAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        if model is None:
+            model = get_model()
+
+        return ToolSelectorAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create tool selector agent", error=str(e))
+        raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e
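
A sketch of how the selector is driven inside a research loop; the gap text and history here are illustrative placeholders, not values from the codebase:

```python
import asyncio

from src.agents.tool_selector import create_tool_selector_agent


async def main() -> None:
    selector = create_tool_selector_agent()
    plan = await selector.select_tools(
        gap="Dosage data for candidate drugs in long COVID trials",  # hypothetical gap
        query="What existing drugs might help treat long COVID fatigue?",
        conversation_history="Iteration 1: searched PubMed for fatigue mechanisms.",
    )
    for task in plan.tasks:  # AgentSelectionPlan.tasks -> list[AgentTask]
        print(task.agent, task.query)


asyncio.run(main())
```

Note the fallback path: if the LLM call fails, `select_tools` degrades to a single `WebSearchAgent` task rather than raising, so callers can always rely on a plan being present.
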
src/agents/writer.py
ADDED
@@ -0,0 +1,209 @@
+"""Writer agent for generating final reports from findings.
+
+Converts the folder/writer_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.utils.exceptions import ConfigurationError
+
+logger = structlog.get_logger()
+
+
+# System prompt for the writer agent
+SYSTEM_PROMPT = f"""
+You are a senior researcher tasked with comprehensively answering a research query.
+Today's date is {datetime.now().strftime('%Y-%m-%d')}.
+You will be provided with the original query along with research findings put together by a research assistant.
+Your objective is to generate the final response in markdown format.
+The response should be as lengthy and detailed as possible with the information provided, focusing on answering the original query.
+In your final output, include references to the source URLs for all information and data gathered.
+This should be formatted in the form of a numbered square bracket next to the relevant information,
+followed by a list of URLs at the end of the response, per the example below.
+
+EXAMPLE REFERENCE FORMAT:
+The company has XYZ products [1]. It operates in the software services market which is expected to grow at 10% per year [2].
+
+References:
+[1] https://example.com/first-source-url
+[2] https://example.com/second-source-url
+
+GUIDELINES:
+* Answer the query directly, do not include unrelated or tangential information.
+* Adhere to any instructions on the length of your final response if provided in the user prompt.
+* If any additional guidelines are provided in the user prompt, follow them exactly and give them precedence over these system instructions.
+"""
+
+
+class WriterAgent:
+    """
+    Agent that generates final reports from research findings.
+
+    Uses Pydantic AI to generate markdown reports with citations.
+    """
+
+    def __init__(self, model: Any | None = None) -> None:
+        """
+        Initialize the writer agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+        """
+        self.model = model or get_model()
+        self.logger = logger
+
+        # Initialize Pydantic AI Agent (no structured output - returns markdown text)
+        self.agent = Agent(
+            model=self.model,
+            system_prompt=SYSTEM_PROMPT,
+            retries=3,
+        )
+
+    async def write_report(
+        self,
+        query: str,
+        findings: str,
+        output_length: str = "",
+        output_instructions: str = "",
+    ) -> str:
+        """
+        Write a final report from findings.
+
+        Args:
+            query: The original research query
+            findings: All findings collected during research
+            output_length: Optional description of desired output length
+            output_instructions: Optional additional instructions
+
+        Returns:
+            Markdown formatted report string. On generation failure a
+            fallback report containing the query and truncated findings
+            is returned instead of raising.
+
+        """
+        # Input validation
+        if not query or not query.strip():
+            self.logger.warning("Empty query provided, using default")
+            query = "Research query"
+
+        if findings is None:
+            self.logger.warning("None findings provided, using empty string")
+            findings = "No findings available."
+
+        # Truncate very long inputs to prevent context overflow
+        max_findings_length = 50000  # ~12k tokens
+        if len(findings) > max_findings_length:
+            self.logger.warning(
+                "Findings too long, truncating",
+                original_length=len(findings),
+                truncated_length=max_findings_length,
+            )
+            findings = findings[:max_findings_length] + "\n\n[Content truncated due to length]"
+
+        self.logger.info("Writing final report", query=query[:100], findings_length=len(findings))
+
+        length_str = (
+            f"* The full response should be approximately {output_length}.\n"
+            if output_length
+            else ""
+        )
+        instructions_str = f"* {output_instructions}" if output_instructions else ""
+        guidelines_str = (
+            ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n")
+            if length_str or instructions_str
+            else ""
+        )
+
+        user_message = f"""
+Provide a response based on the query and findings below with as much detail as possible. {guidelines_str}
+
+QUERY: {query}
+
+FINDINGS:
+{findings}
+"""
+
+        # Retry logic for transient failures
+        max_retries = 3
+        last_exception: Exception | None = None
+
+        for attempt in range(max_retries):
+            try:
+                # Run the agent
+                result = await self.agent.run(user_message)
+                report = result.output
+
+                # Validate output
+                if not report or not report.strip():
+                    self.logger.warning("Empty report generated, using fallback")
+                    raise ValueError("Empty report generated")
+
+                self.logger.info("Report written", length=len(report), attempt=attempt + 1)
+
+                return report
+
+            except (TimeoutError, ConnectionError) as e:
+                # Transient errors - retry
+                last_exception = e
+                if attempt < max_retries - 1:
+                    self.logger.warning(
+                        "Transient error, retrying",
+                        error=str(e),
+                        attempt=attempt + 1,
+                        max_retries=max_retries,
+                    )
+                    continue
+                else:
+                    self.logger.error("Max retries exceeded for transient error", error=str(e))
+                    break
+
+            except Exception as e:
+                # Non-transient errors - don't retry
+                last_exception = e
+                self.logger.error(
+                    "Report writing failed", error=str(e), error_type=type(e).__name__
+                )
+                break
+
+        # Return fallback report if all attempts failed
+        self.logger.error(
+            "Report writing failed after all attempts",
+            error=str(last_exception) if last_exception else "Unknown error",
+        )
+        # Truncate findings in fallback if too long
+        fallback_findings = findings[:500] + "..." if len(findings) > 500 else findings
+        return (
+            f"# Research Report\n\n"
+            f"## Query\n{query}\n\n"
+            f"## Findings\n{fallback_findings}\n\n"
+            f"*Note: Report generation encountered an error. This is a fallback report.*"
+        )
+
+
+def create_writer_agent(model: Any | None = None) -> WriterAgent:
+    """
+    Factory function to create a writer agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured WriterAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        if model is None:
+            model = get_model()
+
+        return WriterAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create writer agent", error=str(e))
+        raise ConfigurationError(f"Failed to create writer agent: {e}") from e
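
Usage sketch for the writer; the findings string is an illustrative placeholder for the concatenated evidence a research flow would pass in:

```python
import asyncio

from src.agents.writer import create_writer_agent


async def main() -> None:
    writer = create_writer_agent()
    report = await writer.write_report(
        query="What existing drugs might help treat long COVID fatigue?",
        findings="[1] https://example.com/source ... evidence text ...",  # placeholder
        output_length="1000 words",
    )
    print(report)  # markdown with [n] citations and a trailing References list


asyncio.run(main())
```

Design note: only `TimeoutError`/`ConnectionError` are treated as transient and retried (up to three attempts); any other exception falls straight through to the fallback report, so a caller never sees `write_report` raise on generation failure.
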
src/{orchestrator.py → legacy_orchestrator.py}
RENAMED
File without changes
src/middleware/__init__.py
ADDED
@@ -0,0 +1,33 @@
+"""Middleware for workflow state management, parallel loop coordination, and budget tracking.
+
+This module provides:
+- WorkflowState: Thread-safe state management using ContextVar
+- WorkflowManager: Coordination of parallel research loops
+- BudgetTracker: Token, time, and iteration budget tracking
+"""
+
+from src.middleware.budget_tracker import BudgetStatus, BudgetTracker
+from src.middleware.state_machine import (
+    WorkflowState,
+    get_workflow_state,
+    init_workflow_state,
+)
+from src.middleware.workflow_manager import (
+    LoopStatus,
+    ResearchLoop,
+    WorkflowManager,
+)
+
+__all__ = [
+    # State management
+    "WorkflowState",
+    "init_workflow_state",
+    "get_workflow_state",
+    # Workflow management
+    "WorkflowManager",
+    "ResearchLoop",
+    "LoopStatus",
+    # Budget tracking
+    "BudgetTracker",
+    "BudgetStatus",
+]
src/middleware/budget_tracker.py
ADDED
@@ -0,0 +1,390 @@
+"""Budget tracking for research loops.
+
+Tracks token usage, time elapsed, and iteration counts per loop and globally.
+Enforces budget constraints to prevent infinite loops and excessive resource usage.
+"""
+
+import time
+
+import structlog
+from pydantic import BaseModel, Field
+
+logger = structlog.get_logger()
+
+
+class BudgetStatus(BaseModel):
+    """Status of a budget (tokens, time, iterations)."""
+
+    tokens_used: int = Field(default=0, description="Total tokens used")
+    tokens_limit: int = Field(default=100000, description="Token budget limit", ge=0)
+    time_elapsed_seconds: float = Field(default=0.0, description="Time elapsed", ge=0.0)
+    time_limit_seconds: float = Field(
+        default=600.0, description="Time budget limit (10 min default)", ge=0.0
+    )
+    iterations: int = Field(default=0, description="Number of iterations completed", ge=0)
+    iterations_limit: int = Field(default=10, description="Maximum iterations", ge=1)
+    iteration_tokens: dict[int, int] = Field(
+        default_factory=dict,
+        description="Tokens used per iteration (iteration number -> token count)",
+    )
+
+    def is_exceeded(self) -> bool:
+        """Check if any budget limit has been exceeded.
+
+        Returns:
+            True if any limit is exceeded, False otherwise.
+        """
+        return (
+            self.tokens_used >= self.tokens_limit
+            or self.time_elapsed_seconds >= self.time_limit_seconds
+            or self.iterations >= self.iterations_limit
+        )
+
+    def remaining_tokens(self) -> int:
+        """Get remaining token budget.
+
+        Returns:
+            Remaining tokens (may be negative if exceeded).
+        """
+        return self.tokens_limit - self.tokens_used
+
+    def remaining_time_seconds(self) -> float:
+        """Get remaining time budget.
+
+        Returns:
+            Remaining time in seconds (may be negative if exceeded).
+        """
+        return self.time_limit_seconds - self.time_elapsed_seconds
+
+    def remaining_iterations(self) -> int:
+        """Get remaining iteration budget.
+
+        Returns:
+            Remaining iterations (may be negative if exceeded).
+        """
+        return self.iterations_limit - self.iterations
+
+    def add_iteration_tokens(self, iteration: int, tokens: int) -> None:
+        """Add tokens for a specific iteration.
+
+        Args:
+            iteration: Iteration number (1-indexed).
+            tokens: Number of tokens to add.
+        """
+        if iteration not in self.iteration_tokens:
+            self.iteration_tokens[iteration] = 0
+        self.iteration_tokens[iteration] += tokens
+        # Also add to total tokens
+        self.tokens_used += tokens
+
+    def get_iteration_tokens(self, iteration: int) -> int:
+        """Get tokens used for a specific iteration.
+
+        Args:
+            iteration: Iteration number.
+
+        Returns:
+            Token count for the iteration, or 0 if not found.
+        """
+        return self.iteration_tokens.get(iteration, 0)
+
+
+class BudgetTracker:
+    """Tracks budgets per loop and globally."""
+
+    def __init__(self) -> None:
+        """Initialize the budget tracker."""
+        self._budgets: dict[str, BudgetStatus] = {}
+        self._start_times: dict[str, float] = {}
+        self._global_budget: BudgetStatus | None = None
+
+    def create_budget(
+        self,
+        loop_id: str,
+        tokens_limit: int = 100000,
+        time_limit_seconds: float = 600.0,
+        iterations_limit: int = 10,
+    ) -> BudgetStatus:
+        """Create a budget for a specific loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            tokens_limit: Maximum tokens allowed.
+            time_limit_seconds: Maximum time allowed in seconds.
+            iterations_limit: Maximum iterations allowed.
+
+        Returns:
+            The created BudgetStatus instance.
+        """
+        budget = BudgetStatus(
+            tokens_limit=tokens_limit,
+            time_limit_seconds=time_limit_seconds,
+            iterations_limit=iterations_limit,
+        )
+        self._budgets[loop_id] = budget
+        logger.debug(
+            "Budget created",
+            loop_id=loop_id,
+            tokens_limit=tokens_limit,
+            time_limit=time_limit_seconds,
+            iterations_limit=iterations_limit,
+        )
+        return budget
+
+    def get_budget(self, loop_id: str) -> BudgetStatus | None:
+        """Get the budget for a specific loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+
+        Returns:
+            The BudgetStatus instance, or None if not found.
+        """
+        return self._budgets.get(loop_id)
+
+    def add_tokens(self, loop_id: str, tokens: int) -> None:
+        """Add tokens to a loop's budget.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            tokens: Number of tokens to add (can be negative).
+        """
+        if loop_id not in self._budgets:
+            logger.warning("Budget not found for loop", loop_id=loop_id)
+            return
+        self._budgets[loop_id].tokens_used += tokens
+        logger.debug("Tokens added", loop_id=loop_id, tokens=tokens)
+
+    def add_iteration_tokens(self, loop_id: str, iteration: int, tokens: int) -> None:
+        """Add tokens for a specific iteration.
+
+        Args:
+            loop_id: Loop identifier.
+            iteration: Iteration number (1-indexed).
+            tokens: Number of tokens to add.
+        """
+        if loop_id not in self._budgets:
+            logger.warning("Budget not found for loop", loop_id=loop_id)
+            return
+
+        budget = self._budgets[loop_id]
+        budget.add_iteration_tokens(iteration, tokens)
+
+        logger.debug(
+            "Iteration tokens added",
+            loop_id=loop_id,
+            iteration=iteration,
+            tokens=tokens,
+            total_iteration=budget.get_iteration_tokens(iteration),
+        )
+
+    def get_iteration_tokens(self, loop_id: str, iteration: int) -> int:
+        """Get tokens used for a specific iteration.
+
+        Args:
+            loop_id: Loop identifier.
+            iteration: Iteration number.
+
+        Returns:
+            Token count for the iteration, or 0 if not found.
+        """
+        if loop_id not in self._budgets:
+            return 0
+
+        return self._budgets[loop_id].get_iteration_tokens(iteration)
+
+    def start_timer(self, loop_id: str) -> None:
+        """Start the timer for a loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        self._start_times[loop_id] = time.time()
+        logger.debug("Timer started", loop_id=loop_id)
+
+    def update_timer(self, loop_id: str) -> None:
+        """Update the elapsed time for a loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id not in self._start_times:
+            logger.warning("Timer not started for loop", loop_id=loop_id)
+            return
+        if loop_id not in self._budgets:
+            logger.warning("Budget not found for loop", loop_id=loop_id)
+            return
+
+        elapsed = time.time() - self._start_times[loop_id]
+        self._budgets[loop_id].time_elapsed_seconds = elapsed
+        logger.debug("Timer updated", loop_id=loop_id, elapsed=elapsed)
+
+    def increment_iteration(self, loop_id: str) -> None:
+        """Increment the iteration count for a loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id not in self._budgets:
+            logger.warning("Budget not found for loop", loop_id=loop_id)
+            return
+        self._budgets[loop_id].iterations += 1
+        logger.debug(
+            "Iteration incremented",
+            loop_id=loop_id,
+            iterations=self._budgets[loop_id].iterations,
+        )
+
+    def check_budget(self, loop_id: str) -> tuple[bool, str]:
+        """Check if a loop's budget has been exceeded.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+
+        Returns:
+            Tuple of (exceeded: bool, reason: str). Reason is empty if not exceeded.
+        """
+        if loop_id not in self._budgets:
+            return False, ""
+
+        budget = self._budgets[loop_id]
+        self.update_timer(loop_id)  # Update time before checking
+
+        if budget.is_exceeded():
+            reasons = []
+            if budget.tokens_used >= budget.tokens_limit:
+                reasons.append("tokens")
+            if budget.time_elapsed_seconds >= budget.time_limit_seconds:
+                reasons.append("time")
+            if budget.iterations >= budget.iterations_limit:
+                reasons.append("iterations")
+            reason = f"Budget exceeded: {', '.join(reasons)}"
+            logger.warning("Budget exceeded", loop_id=loop_id, reason=reason)
+            return True, reason
+
+        return False, ""
+
+    def can_continue(self, loop_id: str) -> bool:
+        """Check if a loop can continue based on budget.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+
+        Returns:
+            True if the loop can continue, False if budget is exceeded.
+        """
+        exceeded, _ = self.check_budget(loop_id)
+        return not exceeded
+
+    def get_budget_summary(self, loop_id: str) -> str:
+        """Get a formatted summary of a loop's budget status.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+
+        Returns:
+            Formatted string summary.
+        """
+        if loop_id not in self._budgets:
+            return f"Budget not found for loop: {loop_id}"
+
+        budget = self._budgets[loop_id]
+        self.update_timer(loop_id)
+
+        return (
+            f"Loop {loop_id}: "
+            f"Tokens: {budget.tokens_used}/{budget.tokens_limit} "
+            f"({budget.remaining_tokens()} remaining), "
+            f"Time: {budget.time_elapsed_seconds:.1f}/{budget.time_limit_seconds:.1f}s "
+            f"({budget.remaining_time_seconds():.1f}s remaining), "
+            f"Iterations: {budget.iterations}/{budget.iterations_limit} "
+            f"({budget.remaining_iterations()} remaining)"
+        )
+
+    def reset_budget(self, loop_id: str) -> None:
+        """Reset the budget for a loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id in self._budgets:
+            old_budget = self._budgets[loop_id]
+            # Preserve iteration_tokens when resetting
+            old_iteration_tokens = old_budget.iteration_tokens
+            self._budgets[loop_id] = BudgetStatus(
+                tokens_limit=old_budget.tokens_limit,
+                time_limit_seconds=old_budget.time_limit_seconds,
+                iterations_limit=old_budget.iterations_limit,
+                iteration_tokens=old_iteration_tokens,  # Restore old iteration tokens
+            )
+        if loop_id in self._start_times:
+            self._start_times[loop_id] = time.time()
+        logger.debug("Budget reset", loop_id=loop_id)
+
+    def set_global_budget(
+        self,
+        tokens_limit: int = 100000,
+        time_limit_seconds: float = 600.0,
+        iterations_limit: int = 10,
+    ) -> None:
+        """Set a global budget that applies to all loops.
+
+        Args:
+            tokens_limit: Maximum tokens allowed globally.
+            time_limit_seconds: Maximum time allowed in seconds.
+            iterations_limit: Maximum iterations allowed globally.
+        """
+        self._global_budget = BudgetStatus(
+            tokens_limit=tokens_limit,
+            time_limit_seconds=time_limit_seconds,
+            iterations_limit=iterations_limit,
+        )
+        logger.debug(
+            "Global budget set",
+            tokens_limit=tokens_limit,
+            time_limit=time_limit_seconds,
+            iterations_limit=iterations_limit,
+        )
+
+    def get_global_budget(self) -> BudgetStatus | None:
+        """Get the global budget.
+
+        Returns:
+            The global BudgetStatus instance, or None if not set.
+        """
+        return self._global_budget
+
+    def add_global_tokens(self, tokens: int) -> None:
+        """Add tokens to the global budget.
+
+        Args:
+            tokens: Number of tokens to add (can be negative).
+        """
+        if self._global_budget is None:
+            logger.warning("Global budget not set")
+            return
+        self._global_budget.tokens_used += tokens
+        logger.debug("Global tokens added", tokens=tokens)
+
+    def estimate_tokens(self, text: str) -> int:
+        """Estimate token count from text (rough estimate: ~4 chars per token).
+
+        Args:
+            text: Text to estimate tokens for.
+
+        Returns:
+            Estimated token count.
+        """
+        return len(text) // 4
+
+    def estimate_llm_call_tokens(self, prompt: str, response: str) -> int:
+        """Estimate token count for an LLM call.
+
+        Args:
+            prompt: The prompt text.
+            response: The response text.
+
+        Returns:
+            Estimated total token count (prompt + response).
+        """
+        return self.estimate_tokens(prompt) + self.estimate_tokens(response)
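
A sketch of the intended call pattern inside a loop driver; the loop id and limits are illustrative, and the prompt/response strings stand in for real LLM traffic:

```python
from src.middleware.budget_tracker import BudgetTracker

tracker = BudgetTracker()
tracker.create_budget("loop-1", tokens_limit=20000, time_limit_seconds=120.0, iterations_limit=5)
tracker.start_timer("loop-1")

iteration = 0
while tracker.can_continue("loop-1"):
    iteration += 1
    tracker.increment_iteration("loop-1")
    # ... one search/judge iteration would run here; record its estimated cost ...
    prompt, response = "example prompt", "example response"  # placeholders
    cost = tracker.estimate_llm_call_tokens(prompt, response)  # len(text) // 4 heuristic
    tracker.add_iteration_tokens("loop-1", iteration, cost)

print(tracker.get_budget_summary("loop-1"))
```

Since the token figures come from the `len(text) // 4` heuristic rather than a real tokenizer, the limits act as soft guardrails, not exact accounting.
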
src/middleware/state_machine.py
ADDED
@@ -0,0 +1,129 @@
+"""Thread-safe state management for workflow agents.
+
+Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users
+searching simultaneously via Gradio). Refactored from MagenticState to support both
+iterative and deep research patterns.
+"""
+
+from contextvars import ContextVar
+from typing import TYPE_CHECKING, Any
+
+import structlog
+from pydantic import BaseModel, Field
+
+from src.utils.models import Citation, Conversation, Evidence
+
+if TYPE_CHECKING:
+    from src.services.embeddings import EmbeddingService
+
+logger = structlog.get_logger()
+
+
+class WorkflowState(BaseModel):
+    """Mutable state for a workflow session.
+
+    Supports both iterative and deep research patterns by tracking evidence,
+    conversation history, and providing semantic search capabilities.
+    """
+
+    evidence: list[Evidence] = Field(default_factory=list)
+    conversation: Conversation = Field(default_factory=Conversation)
+    # Type as Any to avoid circular imports/runtime resolution issues
+    # The actual object injected will be an EmbeddingService instance
+    embedding_service: Any = Field(default=None)
+
+    model_config = {"arbitrary_types_allowed": True}
+
+    def add_evidence(self, new_evidence: list[Evidence]) -> int:
+        """Add new evidence, deduplicating by URL.
+
+        Args:
+            new_evidence: List of Evidence objects to add.
+
+        Returns:
+            Number of *new* items added (excluding duplicates).
+        """
+        existing_urls = {e.citation.url for e in self.evidence}
+        count = 0
+        for item in new_evidence:
+            if item.citation.url not in existing_urls:
+                self.evidence.append(item)
+                existing_urls.add(item.citation.url)
+                count += 1
+        return count
+
+    async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]:
+        """Search for semantically related evidence using the embedding service.
+
+        Args:
+            query: Search query string.
+            n_results: Maximum number of results to return.
+
+        Returns:
+            List of Evidence objects, ordered by relevance.
+        """
+        if not self.embedding_service:
+            logger.warning("Embedding service not available, returning empty results")
+            return []
+
+        results = await self.embedding_service.search_similar(query, n_results=n_results)
+
+        # Convert dict results back to Evidence objects
+        evidence_list = []
+        for item in results:
+            meta = item.get("metadata", {})
+            authors_str = meta.get("authors", "")
+            authors = [a.strip() for a in authors_str.split(",") if a.strip()]
+
+            ev = Evidence(
+                content=item["content"],
+                citation=Citation(
+                    title=meta.get("title", "Related Evidence"),
+                    url=item["id"],
+                    source="pubmed",  # Defaulting to pubmed if unknown
+                    date=meta.get("date", "n.d."),
+                    authors=authors,
+                ),
+                relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
+            )
+            evidence_list.append(ev)
+
+        return evidence_list
+
+
+# The ContextVar holds the WorkflowState for the current execution context
+_workflow_state_var: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None)
+
+
+def init_workflow_state(
+    embedding_service: "EmbeddingService | None" = None,
+) -> WorkflowState:
+    """Initialize a new state for the current context.
+
+    Args:
+        embedding_service: Optional embedding service for semantic search.
+
+    Returns:
+        The initialized WorkflowState instance.
+    """
+    state = WorkflowState(embedding_service=embedding_service)
+    _workflow_state_var.set(state)
+    logger.debug("Workflow state initialized", has_embeddings=embedding_service is not None)
+    return state
+
+
+def get_workflow_state() -> WorkflowState:
+    """Get the current state. Auto-initializes if not set.
+
+    Returns:
+        The current WorkflowState instance.
+
+    Raises:
+        RuntimeError: If state is not initialized and auto-initialization fails.
+    """
+    state = _workflow_state_var.get()
+    if state is None:
+        # Auto-initialize if missing (e.g. during tests or simple scripts)
+        logger.debug("Workflow state not found, auto-initializing")
+        return init_workflow_state()
+    return state
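
Because the state lives in a `ContextVar`, each asyncio task (and hence each concurrent Gradio request, since `asyncio.gather` wraps coroutines in their own tasks with copied contexts) holds its own `WorkflowState` without locking. A minimal sketch of that isolation; the Evidence plumbing is elided:

```python
import asyncio

from src.middleware.state_machine import get_workflow_state, init_workflow_state


async def handle_request() -> int:
    init_workflow_state()  # fresh, isolated state for this task's context
    state = get_workflow_state()
    # ... agents would append Evidence via state.add_evidence(...) here ...
    return len(state.evidence)


async def main() -> None:
    # Two concurrent "requests"; each sees its own empty state.
    counts = await asyncio.gather(handle_request(), handle_request())
    print(counts)  # [0, 0]


asyncio.run(main())
```
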
src/middleware/workflow_manager.py
ADDED
@@ -0,0 +1,322 @@
+"""Workflow manager for coordinating parallel research loops.
+
+Manages multiple research loops running in parallel, tracks their status,
+and synchronizes evidence between loops and the global state.
+"""
+
+import asyncio
+from collections.abc import Callable
+from typing import Any, Literal
+
+import structlog
+from pydantic import BaseModel, Field
+
+from src.middleware.state_machine import get_workflow_state
+from src.utils.models import Evidence
+
+logger = structlog.get_logger()
+
+LoopStatus = Literal["pending", "running", "completed", "failed", "cancelled"]
+
+
+class ResearchLoop(BaseModel):
+    """Represents a single research loop."""
+
+    loop_id: str = Field(description="Unique identifier for the loop")
+    query: str = Field(description="The research query for this loop")
+    status: LoopStatus = Field(default="pending")
+    evidence: list[Evidence] = Field(default_factory=list)
+    iteration_count: int = Field(default=0, ge=0)
+    error: str | None = Field(default=None)
+
+    model_config = {"frozen": False}  # Mutable for status updates
+
+
+class WorkflowManager:
+    """Manages parallel research loops and state synchronization."""
+
+    def __init__(self) -> None:
+        """Initialize the workflow manager."""
+        self._loops: dict[str, ResearchLoop] = {}
+
+    async def add_loop(self, loop_id: str, query: str) -> ResearchLoop:
+        """Add a new research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            query: The research query for this loop.
+
+        Returns:
+            The created ResearchLoop instance.
+        """
+        loop = ResearchLoop(loop_id=loop_id, query=query, status="pending")
+        self._loops[loop_id] = loop
+        logger.info("Loop added", loop_id=loop_id, query=query)
+        return loop
+
+    async def get_loop(self, loop_id: str) -> ResearchLoop | None:
+        """Get a research loop by ID.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+
+        Returns:
+            The ResearchLoop instance, or None if not found.
+        """
+        return self._loops.get(loop_id)
+
+    async def update_loop_status(
+        self, loop_id: str, status: LoopStatus, error: str | None = None
+    ) -> None:
+        """Update the status of a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            status: New status for the loop.
+            error: Optional error message if status is "failed".
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        self._loops[loop_id].status = status
+        if error:
+            self._loops[loop_id].error = error
+        logger.info("Loop status updated", loop_id=loop_id, status=status)
+
+    async def add_loop_evidence(self, loop_id: str, evidence: list[Evidence]) -> None:
+        """Add evidence to a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            evidence: List of Evidence objects to add.
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        self._loops[loop_id].evidence.extend(evidence)
+        logger.debug(
+            "Evidence added to loop",
+            loop_id=loop_id,
+            evidence_count=len(evidence),
+        )
+
+    async def increment_loop_iteration(self, loop_id: str) -> None:
+        """Increment the iteration count for a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        self._loops[loop_id].iteration_count += 1
+        logger.debug(
+            "Iteration incremented",
+            loop_id=loop_id,
+            iteration=self._loops[loop_id].iteration_count,
+        )
+
+    async def run_loops_parallel(
+        self,
+        loop_configs: list[dict[str, Any]],
+        loop_func: Callable[[dict[str, Any]], Any],
+        judge_handler: Any | None = None,
+        budget_tracker: Any | None = None,
+    ) -> list[Any]:
+        """Run multiple research loops in parallel.
+
+        Args:
+            loop_configs: List of configuration dicts, each must contain 'loop_id' and 'query'.
+            loop_func: Async function that takes a config dict and returns loop results.
+            judge_handler: Optional JudgeHandler for early termination based on evidence sufficiency.
+            budget_tracker: Optional BudgetTracker for budget enforcement.
+
+        Returns:
+            List of results from each loop, in the order of loop_configs (asyncio.gather preserves input order).
+        """
+        logger.info("Starting parallel loops", loop_count=len(loop_configs))
+
+        # Create loops
+        for config in loop_configs:
+            loop_id = config.get("loop_id")
+            query = config.get("query", "")
+            if loop_id:
+                await self.add_loop(loop_id, query)
+                await self.update_loop_status(loop_id, "running")
+
+        # Run loops in parallel
+        async def run_single_loop(config: dict[str, Any]) -> Any:
+            loop_id = config.get("loop_id", "unknown")
+            query = config.get("query", "")
+            try:
+                # Check budget before starting
+                if budget_tracker:
+                    exceeded, reason = budget_tracker.check_budget(loop_id)
+                    if exceeded:
+                        await self.update_loop_status(loop_id, "cancelled", error=reason)
+                        logger.warning(
+                            "Loop cancelled due to budget", loop_id=loop_id, reason=reason
+                        )
+                        return None
+
+                # If loop_func supports periodic checkpoints, we could check judge here
+                # For now, the loop_func itself handles judge checks internally
+                result = await loop_func(config)
+
+                # Final check with judge if available
+                if judge_handler and query:
+                    should_complete, reason = await self.check_loop_completion(
+                        loop_id, query, judge_handler
+                    )
+                    if should_complete:
+                        logger.info(
+                            "Loop completed early based on judge assessment",
+                            loop_id=loop_id,
+                            reason=reason,
+                        )
+
+                await self.update_loop_status(loop_id, "completed")
+                return result
+            except Exception as e:
+                error_msg = str(e)
+                await self.update_loop_status(loop_id, "failed", error=error_msg)
+                logger.error("Loop failed", loop_id=loop_id, error=error_msg)
+                raise
+
+        results = await asyncio.gather(
+            *(run_single_loop(config) for config in loop_configs),
+            return_exceptions=True,
+        )
+
+        # Log completion
+        completed = sum(1 for r in results if not isinstance(r, Exception))
+        failed = len(results) - completed
+        logger.info(
+            "Parallel loops completed",
+            total=len(loop_configs),
+            completed=completed,
+            failed=failed,
+        )
+
+        return results
+
+    async def wait_for_loops(
+        self, loop_ids: list[str], timeout: float | None = None
+    ) -> list[ResearchLoop]:
+        """Wait for loops to complete.
+
+        Args:
+            loop_ids: List of loop IDs to wait for.
+            timeout: Optional timeout in seconds.
+
+        Returns:
+            List of ResearchLoop instances (may be incomplete if timeout occurs).
+        """
+        start_time = asyncio.get_event_loop().time()
+
+        while True:
+            loops = [self._loops.get(loop_id) for loop_id in loop_ids]
+            all_complete = all(
+                loop and loop.status in ("completed", "failed", "cancelled") for loop in loops
+            )
+
+            if all_complete:
+                return [loop for loop in loops if loop is not None]
+
+            if timeout is not None:
+                elapsed = asyncio.get_event_loop().time() - start_time
+                if elapsed >= timeout:
+                    logger.warning("Timeout waiting for loops", timeout=timeout)
+                    return [loop for loop in loops if loop is not None]
+
+            await asyncio.sleep(0.1)  # Small delay to avoid busy waiting
+
+    async def cancel_loop(self, loop_id: str) -> None:
+        """Cancel a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        await self.update_loop_status(loop_id, "cancelled")
+        logger.info("Loop cancelled", loop_id=loop_id)
+
+    async def get_all_loops(self) -> list[ResearchLoop]:
+        """Get all research loops.
+
+        Returns:
+            List of all ResearchLoop instances.
+        """
+        return list(self._loops.values())
+
+    async def sync_loop_evidence_to_state(self, loop_id: str) -> None:
+        """Synchronize evidence from a loop to the global state.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        loop = self._loops[loop_id]
+        state = get_workflow_state()
+        added_count = state.add_evidence(loop.evidence)
+        logger.debug(
+            "Loop evidence synced to state",
+            loop_id=loop_id,
+            evidence_count=len(loop.evidence),
+            added_count=added_count,
+        )
+
+    async def get_shared_evidence(self) -> list[Evidence]:
+        """Get evidence from the global state.
+
+        Returns:
+            List of Evidence objects from the global state.
+        """
+        state = get_workflow_state()
+        return state.evidence
+
+    async def get_loop_evidence(self, loop_id: str) -> list[Evidence]:
+        """Get evidence collected by a specific loop.
+
+        Args:
+            loop_id: Loop identifier.
+
+        Returns:
+            List of Evidence objects from the loop.
+        """
+        if loop_id not in self._loops:
+            return []
+
+        return self._loops[loop_id].evidence
+
+    async def check_loop_completion(
+        self, loop_id: str, query: str, judge_handler: Any
+    ) -> tuple[bool, str]:
+        """Check if a loop should complete using judge assessment.
+
+        Args:
+            loop_id: Loop identifier.
+            query: Research query.
+            judge_handler: JudgeHandler instance.
+
+        Returns:
+            Tuple of (should_complete: bool, reason: str).
+        """
+        evidence = await self.get_loop_evidence(loop_id)
+
+        if not evidence:
+            return False, "No evidence collected yet"
+
+        try:
+            assessment = await judge_handler.assess(query, evidence)
+            if assessment.sufficient:
+                return True, f"Judge assessment: {assessment.reasoning}"
+            return False, f"Judge assessment: {assessment.reasoning}"
+        except Exception as e:
+            logger.error("Judge assessment failed", error=str(e), loop_id=loop_id)
+            return False, f"Judge assessment failed: {e!s}"
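
Sketch of wiring two parallel loops through the manager; the `run_loop` body is a stand-in for a real iterative research function, and the loop ids and queries are illustrative:

```python
import asyncio
from typing import Any

from src.middleware.workflow_manager import WorkflowManager


async def run_loop(config: dict[str, Any]) -> str:
    # Placeholder for IterativeResearchFlow-style work on config["query"].
    await asyncio.sleep(0.1)
    return f"findings for {config['query']}"


async def main() -> None:
    manager = WorkflowManager()
    configs = [
        {"loop_id": "loop-1", "query": "long COVID fatigue mechanisms"},
        {"loop_id": "loop-2", "query": "repurposed drugs for fatigue"},
    ]
    results = await manager.run_loops_parallel(configs, run_loop)
    for loop in await manager.get_all_loops():
        print(loop.loop_id, loop.status)
    print(results)


asyncio.run(main())
```

Since `run_loops_parallel` passes `return_exceptions=True` to `asyncio.gather`, a failed loop surfaces as an exception object in `results` rather than aborting its siblings.
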
src/orchestrator/__init__.py
ADDED
@@ -0,0 +1,48 @@
+"""Orchestrator module for research flows and planner agent.
+
+This module provides:
+- PlannerAgent: Creates report plans with sections
+- IterativeResearchFlow: Single research loop pattern
+- DeepResearchFlow: Parallel research loops pattern
+- GraphOrchestrator: Stub for Phase 4 (uses agent chains for now)
+- Protocols: SearchHandlerProtocol, JudgeHandlerProtocol (re-exported from legacy_orchestrator)
+- Orchestrator: Legacy orchestrator class (re-exported from legacy_orchestrator)
+"""
+
+from typing import TYPE_CHECKING
+
+# Re-export protocols and Orchestrator from legacy_orchestrator for backward compatibility
+from src.legacy_orchestrator import (
+    JudgeHandlerProtocol,
+    Orchestrator,
+    SearchHandlerProtocol,
+)
+
+# Lazy imports to avoid circular dependencies
+if TYPE_CHECKING:
+    from src.orchestrator.graph_orchestrator import GraphOrchestrator
+    from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent
+    from src.orchestrator.research_flow import (
+        DeepResearchFlow,
+        IterativeResearchFlow,
+    )
+
+# Public exports
+from src.orchestrator.graph_orchestrator import (
+    GraphOrchestrator,
+    create_graph_orchestrator,
+)
+from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent
+from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
+
+__all__ = [
+    "PlannerAgent",
+    "create_planner_agent",
+    "IterativeResearchFlow",
+    "DeepResearchFlow",
+    "GraphOrchestrator",
+    "create_graph_orchestrator",
+    "SearchHandlerProtocol",
+    "JudgeHandlerProtocol",
+    "Orchestrator",
+]
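
Because the protocols and the legacy `Orchestrator` are re-exported here, call sites that imported from the old `src.orchestrator` module keep working after the rename to `legacy_orchestrator.py`:

```python
# Resolves via the package re-exports above; no call-site changes needed.
from src.orchestrator import DeepResearchFlow, IterativeResearchFlow, Orchestrator
```
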
src/orchestrator/graph_orchestrator.py
ADDED
@@ -0,0 +1,953 @@
+"""Graph orchestrator for Phase 4.
+
+Implements graph-based orchestration using Pydantic AI agents as nodes.
+Supports both iterative and deep research patterns with parallel execution.
+"""
+
+import asyncio
+from collections.abc import AsyncGenerator, Callable
+from typing import TYPE_CHECKING, Any, Literal
+
+import structlog
+
+from src.agent_factory.agents import (
+    create_input_parser_agent,
+    create_knowledge_gap_agent,
+    create_long_writer_agent,
+    create_planner_agent,
+    create_thinking_agent,
+    create_tool_selector_agent,
+    create_writer_agent,
+)
+from src.agent_factory.graph_builder import (
+    AgentNode,
+    DecisionNode,
+    ParallelNode,
+    ResearchGraph,
+    StateNode,
+    create_deep_graph,
+    create_iterative_graph,
+)
+from src.middleware.budget_tracker import BudgetTracker
+from src.middleware.state_machine import WorkflowState, init_workflow_state
+from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
+from src.utils.models import AgentEvent
+
+if TYPE_CHECKING:
+    pass
+
+logger = structlog.get_logger()
+
+
+class GraphExecutionContext:
+    """Context for managing graph execution state."""
+
+    def __init__(self, state: WorkflowState, budget_tracker: BudgetTracker) -> None:
+        """Initialize execution context.
+
+        Args:
+            state: Current workflow state
+            budget_tracker: Budget tracker instance
+        """
+        self.current_node: str = ""
+        self.visited_nodes: set[str] = set()
+        self.node_results: dict[str, Any] = {}
+        self.state = state
+        self.budget_tracker = budget_tracker
+        self.iteration_count = 0
+
+    def set_node_result(self, node_id: str, result: Any) -> None:
+        """Store result from node execution.
+
+        Args:
+            node_id: The node ID
+            result: The execution result
+        """
+        self.node_results[node_id] = result
+
+    def get_node_result(self, node_id: str) -> Any:
+        """Get result from node execution.
+
+        Args:
+            node_id: The node ID
+
+        Returns:
+            The stored result, or None if not found
+        """
+        return self.node_results.get(node_id)
+
+    def has_visited(self, node_id: str) -> bool:
+        """Check if node was visited.
+
+        Args:
+            node_id: The node ID
+
+        Returns:
+            True if visited, False otherwise
+        """
+        return node_id in self.visited_nodes
+
+    def mark_visited(self, node_id: str) -> None:
+        """Mark node as visited.
+
+        Args:
+            node_id: The node ID
+        """
+        self.visited_nodes.add(node_id)
+
+    def update_state(
+        self, updater: Callable[[WorkflowState, Any], WorkflowState], data: Any
+    ) -> None:
+        """Update workflow state.
+
+        Args:
+            updater: Function to update state
+            data: Data to pass to updater
+        """
+        self.state = updater(self.state, data)
+
+
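`GraphExecutionContext` is a plain per-run store that downstream nodes read from rather than passing results along edges; the contract, in short (state and tracker built exactly as in `_run_with_graph` further down):

    # state: WorkflowState from init_workflow_state(...), tracker: BudgetTracker
    ctx = GraphExecutionContext(state, tracker)
    ctx.set_node_result("planner", report_plan)
    ctx.mark_visited("planner")
    assert ctx.has_visited("planner")
    plan = ctx.get_node_result("planner")  # returns None for nodes that never ran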
+class GraphOrchestrator:
+    """
+    Graph orchestrator using Pydantic AI Graphs.
+
+    Executes research workflows as graphs with nodes (agents) and edges (transitions).
+    Supports parallel execution, conditional routing, and state management.
+    """
+
+    def __init__(
+        self,
+        mode: Literal["iterative", "deep", "auto"] = "auto",
+        max_iterations: int = 5,
+        max_time_minutes: int = 10,
+        use_graph: bool = True,
+    ) -> None:
+        """
+        Initialize graph orchestrator.
+
+        Args:
+            mode: Research mode ("iterative", "deep", or "auto" to detect)
+            max_iterations: Maximum iterations per loop
+            max_time_minutes: Maximum time per loop
+            use_graph: Whether to use graph execution (True) or agent chains (False)
+        """
+        self.mode = mode
+        self.max_iterations = max_iterations
+        self.max_time_minutes = max_time_minutes
+        self.use_graph = use_graph
+        self.logger = logger
+
+        # Initialize flows (for backward compatibility)
+        self._iterative_flow: IterativeResearchFlow | None = None
+        self._deep_flow: DeepResearchFlow | None = None
+
+        # Graph execution components (lazy initialization)
+        self._graph: ResearchGraph | None = None
+        self._budget_tracker: BudgetTracker | None = None
+
+    async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
+        """
+        Run the research workflow.
+
+        Args:
+            query: The user's research query
+
+        Yields:
+            AgentEvent objects for real-time UI updates
+        """
+        self.logger.info(
+            "Starting graph orchestrator",
+            query=query[:100],
+            mode=self.mode,
+            use_graph=self.use_graph,
+        )
+
+        yield AgentEvent(
+            type="started",
+            message=f"Starting research ({self.mode} mode): {query}",
+            iteration=0,
+        )
+
+        try:
+            # Determine research mode
+            research_mode = self.mode
+            if research_mode == "auto":
+                research_mode = await self._detect_research_mode(query)
+
+            # Use graph execution if enabled, otherwise fall back to agent chains
+            if self.use_graph:
+                async for event in self._run_with_graph(query, research_mode):
+                    yield event
+            else:
+                async for event in self._run_with_chains(query, research_mode):
+                    yield event
+
+        except Exception as e:
+            self.logger.error("Graph orchestrator failed", error=str(e), exc_info=True)
+            yield AgentEvent(
+                type="error",
+                message=f"Research failed: {e!s}",
+                iteration=0,
+            )
+
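Because `run()` is an async generator, a UI layer consumes the workflow as a stream of `AgentEvent`s:

    import asyncio

    from src.orchestrator import GraphOrchestrator

    async def main() -> None:
        orchestrator = GraphOrchestrator(mode="auto")
        async for event in orchestrator.run("What existing drugs might help treat long COVID fatigue?"):
            print(event.type, event.message[:80])  # started, searching, ..., complete

    asyncio.run(main())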
+    async def _run_with_graph(
+        self, query: str, research_mode: Literal["iterative", "deep"]
+    ) -> AsyncGenerator[AgentEvent, None]:
+        """Run workflow using graph execution.
+
+        Args:
+            query: The research query
+            research_mode: The research mode
+
+        Yields:
+            AgentEvent objects
+        """
+        # Initialize state and budget tracker
+        from src.services.embeddings import get_embedding_service
+
+        embedding_service = get_embedding_service()
+        state = init_workflow_state(embedding_service=embedding_service)
+        budget_tracker = BudgetTracker()
+        budget_tracker.create_budget(
+            loop_id="graph_execution",
+            tokens_limit=100000,
+            time_limit_seconds=self.max_time_minutes * 60,
+            iterations_limit=self.max_iterations,
+        )
+        budget_tracker.start_timer("graph_execution")
+
+        context = GraphExecutionContext(state, budget_tracker)
+
+        # Build graph
+        self._graph = await self._build_graph(research_mode)
+
+        # Execute graph
+        async for event in self._execute_graph(query, context):
+            yield event
+
+    async def _run_with_chains(
+        self, query: str, research_mode: Literal["iterative", "deep"]
+    ) -> AsyncGenerator[AgentEvent, None]:
+        """Run workflow using agent chains (backward compatibility).
+
+        Args:
+            query: The research query
+            research_mode: The research mode
+
+        Yields:
+            AgentEvent objects
+        """
+        if research_mode == "iterative":
+            yield AgentEvent(
+                type="searching",
+                message="Running iterative research flow...",
+                iteration=1,
+            )
+
+            if self._iterative_flow is None:
+                self._iterative_flow = IterativeResearchFlow(
+                    max_iterations=self.max_iterations,
+                    max_time_minutes=self.max_time_minutes,
+                )
+
+            final_report = await self._iterative_flow.run(query)
+
+            yield AgentEvent(
+                type="complete",
+                message=final_report,
+                data={"mode": "iterative"},
+                iteration=1,
+            )
+
+        elif research_mode == "deep":
+            yield AgentEvent(
+                type="searching",
+                message="Running deep research flow...",
+                iteration=1,
+            )
+
+            if self._deep_flow is None:
+                self._deep_flow = DeepResearchFlow(
+                    max_iterations=self.max_iterations,
+                    max_time_minutes=self.max_time_minutes,
+                )
+
+            final_report = await self._deep_flow.run(query)
+
+            yield AgentEvent(
+                type="complete",
+                message=final_report,
+                data={"mode": "deep"},
+                iteration=1,
+            )
+
+    async def _build_graph(self, mode: Literal["iterative", "deep"]) -> ResearchGraph:
+        """Build graph for the specified mode.
+
+        Args:
+            mode: Research mode
+
+        Returns:
+            Constructed ResearchGraph
+        """
+        if mode == "iterative":
+            # Get agents
+            knowledge_gap_agent = create_knowledge_gap_agent()
+            tool_selector_agent = create_tool_selector_agent()
+            thinking_agent = create_thinking_agent()
+            writer_agent = create_writer_agent()
+
+            # Create graph
+            graph = create_iterative_graph(
+                knowledge_gap_agent=knowledge_gap_agent.agent,
+                tool_selector_agent=tool_selector_agent.agent,
+                thinking_agent=thinking_agent.agent,
+                writer_agent=writer_agent.agent,
+            )
+        else:  # deep
+            # Get agents
+            planner_agent = create_planner_agent()
+            knowledge_gap_agent = create_knowledge_gap_agent()
+            tool_selector_agent = create_tool_selector_agent()
+            thinking_agent = create_thinking_agent()
+            writer_agent = create_writer_agent()
+            long_writer_agent = create_long_writer_agent()
+
+            # Create graph
+            graph = create_deep_graph(
+                planner_agent=planner_agent.agent,
+                knowledge_gap_agent=knowledge_gap_agent.agent,
+                tool_selector_agent=tool_selector_agent.agent,
+                thinking_agent=thinking_agent.agent,
+                writer_agent=writer_agent.agent,
+                long_writer_agent=long_writer_agent.agent,
+            )
+
+        return graph
+
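The same wiring is available outside the orchestrator; a sketch of building the iterative graph directly from the factories imported at the top of this file:

    from src.agent_factory.agents import (
        create_knowledge_gap_agent,
        create_thinking_agent,
        create_tool_selector_agent,
        create_writer_agent,
    )
    from src.agent_factory.graph_builder import create_iterative_graph

    graph = create_iterative_graph(
        knowledge_gap_agent=create_knowledge_gap_agent().agent,
        tool_selector_agent=create_tool_selector_agent().agent,
        thinking_agent=create_thinking_agent().agent,
        writer_agent=create_writer_agent().agent,
    )
    print(graph.entry_node, graph.exit_nodes)  # the endpoints _execute_graph walks between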
+    def _emit_start_event(
+        self, node: Any, current_node_id: str, iteration: int, context: GraphExecutionContext
+    ) -> AgentEvent:
+        """Emit start event for a node.
+
+        Args:
+            node: The node being executed
+            current_node_id: Current node ID
+            iteration: Current iteration number
+            context: Execution context
+
+        Returns:
+            AgentEvent for the start of node execution
+        """
+        if node and node.node_id == "planner":
+            return AgentEvent(
+                type="searching",
+                message="Creating report plan...",
+                iteration=iteration,
+            )
+        elif node and node.node_id == "parallel_loops":
+            # Get report plan to show section count
+            report_plan = context.get_node_result("planner")
+            if report_plan and hasattr(report_plan, "report_outline"):
+                section_count = len(report_plan.report_outline)
+                return AgentEvent(
+                    type="looping",
+                    message=f"Running parallel research loops for {section_count} sections...",
+                    iteration=iteration,
+                    data={"sections": section_count},
+                )
+            return AgentEvent(
+                type="looping",
+                message="Running parallel research loops...",
+                iteration=iteration,
+            )
+        elif node and node.node_id == "synthesizer":
+            return AgentEvent(
+                type="synthesizing",
+                message="Synthesizing final report from section drafts...",
+                iteration=iteration,
+            )
+        return AgentEvent(
+            type="looping",
+            message=f"Executing node: {current_node_id}",
+            iteration=iteration,
+        )
+
+    def _emit_completion_event(
+        self, node: Any, current_node_id: str, result: Any, iteration: int
+    ) -> AgentEvent:
+        """Emit completion event for a node.
+
+        Args:
+            node: The node that was executed
+            current_node_id: Current node ID
+            result: Node execution result
+            iteration: Current iteration number
+
+        Returns:
+            AgentEvent for the completion of node execution
+        """
+        if not node:
+            return AgentEvent(
+                type="looping",
+                message=f"Completed node: {current_node_id}",
+                iteration=iteration,
+            )
+
+        if node.node_id == "planner":
+            if isinstance(result, dict) and "report_outline" in result:
+                section_count = len(result["report_outline"])
+                return AgentEvent(
+                    type="search_complete",
+                    message=f"Report plan created with {section_count} sections",
+                    iteration=iteration,
+                    data={"sections": section_count},
+                )
+            return AgentEvent(
+                type="search_complete",
+                message="Report plan created",
+                iteration=iteration,
+            )
+        elif node.node_id == "parallel_loops":
+            if isinstance(result, list):
+                return AgentEvent(
+                    type="search_complete",
+                    message=f"Completed parallel research for {len(result)} sections",
+                    iteration=iteration,
+                    data={"sections_completed": len(result)},
+                )
+            return AgentEvent(
+                type="search_complete",
+                message="Parallel research loops completed",
+                iteration=iteration,
+            )
+        elif node.node_id == "synthesizer":
+            return AgentEvent(
+                type="synthesizing",
+                message="Final report synthesis completed",
+                iteration=iteration,
+            )
+        return AgentEvent(
+            type="searching" if node.node_type == "agent" else "looping",
+            message=f"Completed {node.node_type} node: {current_node_id}",
+            iteration=iteration,
+        )
+
+    async def _execute_graph(
+        self, query: str, context: GraphExecutionContext
+    ) -> AsyncGenerator[AgentEvent, None]:
+        """Execute the graph from entry node.
+
+        Args:
+            query: The research query
+            context: Execution context
+
+        Yields:
+            AgentEvent objects
+        """
+        if not self._graph:
+            raise ValueError("Graph not built")
+
+        current_node_id = self._graph.entry_node
+        iteration = 0
+
+        while current_node_id and current_node_id not in self._graph.exit_nodes:
+            # Check budget
+            if not context.budget_tracker.can_continue("graph_execution"):
+                self.logger.warning("Budget exceeded, exiting graph execution")
+                break
+
+            # Execute current node
+            iteration += 1
+            context.current_node = current_node_id
+            node = self._graph.get_node(current_node_id)
+
+            # Emit start event
+            yield self._emit_start_event(node, current_node_id, iteration, context)
+
+            try:
+                result = await self._execute_node(current_node_id, query, context)
+                context.set_node_result(current_node_id, result)
+                context.mark_visited(current_node_id)
+
+                # Yield completion event
+                yield self._emit_completion_event(node, current_node_id, result, iteration)
+
+            except Exception as e:
+                self.logger.error("Node execution failed", node_id=current_node_id, error=str(e))
+                yield AgentEvent(
+                    type="error",
+                    message=f"Node {current_node_id} failed: {e!s}",
+                    iteration=iteration,
+                )
+                break
+
+            # Get next node(s)
+            next_nodes = self._get_next_node(current_node_id, context)
+
+            if not next_nodes:
+                # No more nodes, check if we're at exit
+                if current_node_id in self._graph.exit_nodes:
+                    break
+                # Otherwise, we've reached a dead end
+                self.logger.warning("Reached dead end in graph", node_id=current_node_id)
+                break
+
+            current_node_id = next_nodes[0]  # For now, take first next node (handle parallel later)
+
+        # Final event
+        final_result = context.get_node_result(current_node_id) if current_node_id else None
+        yield AgentEvent(
+            type="complete",
+            message=final_result if isinstance(final_result, str) else "Research completed",
+            data={"mode": self.mode, "iterations": iteration},
+            iteration=iteration,
+        )
+
+    async def _execute_node(self, node_id: str, query: str, context: GraphExecutionContext) -> Any:
+        """Execute a single node.
+
+        Args:
+            node_id: The node ID
+            query: The research query
+            context: Execution context
+
+        Returns:
+            Node execution result
+        """
+        if not self._graph:
+            raise ValueError("Graph not built")
+
+        node = self._graph.get_node(node_id)
+        if not node:
+            raise ValueError(f"Node {node_id} not found")
+
+        if isinstance(node, AgentNode):
+            return await self._execute_agent_node(node, query, context)
+        elif isinstance(node, StateNode):
+            return await self._execute_state_node(node, query, context)
+        elif isinstance(node, DecisionNode):
+            return await self._execute_decision_node(node, query, context)
+        elif isinstance(node, ParallelNode):
+            return await self._execute_parallel_node(node, query, context)
+        else:
+            raise ValueError(f"Unknown node type: {type(node)}")
+
+    async def _execute_agent_node(
+        self, node: AgentNode, query: str, context: GraphExecutionContext
+    ) -> Any:
+        """Execute an agent node.
+
+        Special handling for deep research nodes:
+        - "planner": Takes query string, returns ReportPlan
+        - "synthesizer": Takes query + ReportPlan + section drafts, returns final report
+
+        Args:
+            node: The agent node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            Agent execution result
+        """
+        # Special handling for synthesizer node
+        if node.node_id == "synthesizer":
+            # Call LongWriterAgent.write_report() directly instead of using agent.run()
+            from src.agent_factory.agents import create_long_writer_agent
+            from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan
+
+            report_plan = context.get_node_result("planner")
+            section_drafts = context.get_node_result("parallel_loops") or []
+
+            if not isinstance(report_plan, ReportPlan):
+                raise ValueError("ReportPlan not found for synthesizer")
+
+            if not section_drafts:
+                raise ValueError("Section drafts not found for synthesizer")
+
+            # Create ReportDraft from section drafts
+            report_draft = ReportDraft(
+                sections=[
+                    ReportDraftSection(
+                        section_title=section.title,
+                        section_content=draft,
+                    )
+                    for section, draft in zip(
+                        report_plan.report_outline, section_drafts, strict=False
+                    )
+                ]
+            )
+
+            # Get LongWriterAgent instance and call write_report directly
+            long_writer_agent = create_long_writer_agent()
+            final_report = await long_writer_agent.write_report(
+                original_query=query,
+                report_title=report_plan.report_title,
+                report_draft=report_draft,
+            )
+
+            # Estimate tokens (rough estimate)
+            estimated_tokens = len(final_report) // 4  # Rough token estimate
+            context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
+
+            return final_report
+
+        # Standard agent execution
+        # Prepare input based on node type
+        if node.node_id == "planner":
+            # Planner takes the original query
+            input_data = query
+        else:
+            # Standard: use previous node result or query
+            prev_result = context.get_node_result(context.current_node)
+            input_data = prev_result if prev_result is not None else query
+
+        # Apply input transformer if provided
+        if node.input_transformer:
+            input_data = node.input_transformer(input_data)
+
+        # Execute agent
+        result = await node.agent.run(input_data)
+
+        # Transform output if needed
+        output = result.output
+        if node.output_transformer:
+            output = node.output_transformer(output)
+
+        # Estimate and track tokens
+        if hasattr(result, "usage") and result.usage:
+            tokens = result.usage.total_tokens if hasattr(result.usage, "total_tokens") else 0
+            context.budget_tracker.add_tokens("graph_execution", tokens)
+
+        return output
+
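The transformer hooks above give each agent node a thin adapter layer around `agent.run()`. A sketch of the pattern (the `AgentNode` constructor keywords are assumptions; only the attribute names are confirmed by this file):

    # Hypothetical node whose input is reshaped before the agent sees it
    node = AgentNode(
        node_id="writer",
        agent=create_writer_agent().agent,
        input_transformer=lambda gaps: f"Write a report addressing: {gaps}",
        output_transformer=lambda report: report.strip(),
    )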
+    async def _execute_state_node(
+        self, node: StateNode, query: str, context: GraphExecutionContext
+    ) -> Any:
+        """Execute a state node.
+
+        Special handling for deep research state nodes:
+        - "store_plan": Stores ReportPlan in context for parallel loops
+        - "collect_drafts": Stores section drafts in context for synthesizer
+
+        Args:
+            node: The state node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            State update result
+        """
+        # Get previous result for state update
+        # For "store_plan", get from planner node
+        # For "collect_drafts", get from parallel_loops node
+        if node.node_id == "store_plan":
+            prev_result = context.get_node_result("planner")
+        elif node.node_id == "collect_drafts":
+            prev_result = context.get_node_result("parallel_loops")
+        else:
+            prev_result = context.get_node_result(context.current_node)
+
+        # Update state
+        updated_state = node.state_updater(context.state, prev_result)
+        context.state = updated_state
+
+        # Store result in context for next nodes to access
+        context.set_node_result(node.node_id, prev_result)
+
+        # Read state if needed
+        if node.state_reader:
+            return node.state_reader(context.state)
+
+        return prev_result  # Return the stored result for next nodes
+
+    async def _execute_decision_node(
+        self, node: DecisionNode, query: str, context: GraphExecutionContext
+    ) -> str:
+        """Execute a decision node.
+
+        Args:
+            node: The decision node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            Next node ID
+        """
+        # Get previous result for decision
+        prev_result = context.get_node_result(context.current_node)
+
+        # Make decision
+        next_node_id = node.decision_function(prev_result)
+
+        # Validate decision
+        if next_node_id not in node.options:
+            self.logger.warning(
+                "Decision function returned invalid node",
+                node_id=node.node_id,
+                returned=next_node_id,
+                options=node.options,
+            )
+            # Default to first option
+            next_node_id = node.options[0]
+
+        return next_node_id
+
+    async def _execute_parallel_node(
+        self, node: ParallelNode, query: str, context: GraphExecutionContext
+    ) -> list[Any]:
+        """Execute a parallel node.
+
+        Special handling for deep research "parallel_loops" node:
+        - Extracts report plan from previous node result
+        - Creates IterativeResearchFlow instances for each section
+        - Executes them in parallel
+        - Returns section drafts
+
+        Args:
+            node: The parallel node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            List of results from parallel nodes
+        """
+        # Special handling for deep research parallel_loops node
+        if node.node_id == "parallel_loops":
+            return await self._execute_deep_research_parallel_loops(node, query, context)
+
+        # Standard parallel node execution
+        # Execute all parallel nodes concurrently
+        tasks = [
+            self._execute_node(parallel_node_id, query, context)
+            for parallel_node_id in node.parallel_nodes
+        ]
+
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Handle exceptions
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                self.logger.error(
+                    "Parallel node execution failed",
+                    node_id=node.parallel_nodes[i] if i < len(node.parallel_nodes) else "unknown",
+                    error=str(result),
+                )
+                results[i] = None
+
+        # Aggregate if needed
+        if node.aggregator:
+            aggregated = node.aggregator(results)
+            # Type cast: aggregator returns Any, but we expect list[Any]
+            return list(aggregated) if isinstance(aggregated, list) else [aggregated]
+
+        return results
+
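For contrast with the special-cased `parallel_loops` path, the decision-node contract used by `_execute_decision_node` is just a callable plus an allow-list of node IDs. A sketch (the constructor keywords and the `research_complete` field are assumptions):

    # Hypothetical router: hand off to the writer once no gaps remain
    gate = DecisionNode(
        node_id="gap_check",
        decision_function=lambda gaps: "writer" if gaps.research_complete else "tool_selector",
        options=["writer", "tool_selector"],
    )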
+    async def _execute_deep_research_parallel_loops(
+        self, node: ParallelNode, query: str, context: GraphExecutionContext
+    ) -> list[str]:
+        """Execute parallel iterative research loops for deep research.
+
+        Args:
+            node: The parallel node (should be "parallel_loops")
+            query: The research query
+            context: Execution context
+
+        Returns:
+            List of section draft strings
+        """
+        from src.agent_factory.judges import create_judge_handler
+        from src.orchestrator.research_flow import IterativeResearchFlow
+        from src.utils.models import ReportPlan
+
+        # Get report plan from previous node (store_plan)
+        # The plan should be stored in context.node_results from the planner node
+        planner_result = context.get_node_result("planner")
+        if not isinstance(planner_result, ReportPlan):
+            self.logger.error(
+                "Planner result is not a ReportPlan",
+                type=type(planner_result),
+            )
+            raise ValueError("Planner must return ReportPlan for deep research")
+
+        report_plan: ReportPlan = planner_result
+        self.logger.info(
+            "Executing parallel loops for deep research",
+            sections=len(report_plan.report_outline),
+        )
+
+        # Create judge handler for iterative flows
+        judge_handler = create_judge_handler()
+
+        # Create and execute iterative research flows for each section
+        async def run_section_research(section_index: int) -> str:
+            """Run iterative research for a single section."""
+            section = report_plan.report_outline[section_index]
+
+            try:
+                # Create iterative research flow
+                flow = IterativeResearchFlow(
+                    max_iterations=self.max_iterations,
+                    max_time_minutes=self.max_time_minutes,
+                    verbose=False,  # Less verbose in parallel execution
+                    use_graph=False,  # Use agent chains for section research
+                    judge_handler=judge_handler,
+                )
+
+                # Run research for this section
+                section_draft = await flow.run(
+                    query=section.key_question,
+                    background_context=report_plan.background_context,
+                )
+
+                self.logger.info(
+                    "Section research completed",
+                    section_index=section_index,
+                    section_title=section.title,
+                    draft_length=len(section_draft),
+                )
+
+                return section_draft
+
+            except Exception as e:
+                self.logger.error(
+                    "Section research failed",
+                    section_index=section_index,
+                    section_title=section.title,
+                    error=str(e),
+                )
+                # Return empty string for failed sections
+                return f"# {section.title}\n\n[Research failed: {e!s}]"
+
+        # Execute all sections in parallel
+        section_drafts = await asyncio.gather(
+            *(run_section_research(i) for i in range(len(report_plan.report_outline))),
+            return_exceptions=True,
+        )
+
+        # Handle exceptions and filter None results
+        filtered_drafts: list[str] = []
+        for i, draft in enumerate(section_drafts):
+            if isinstance(draft, Exception):
+                self.logger.error(
+                    "Section research exception",
+                    section_index=i,
+                    error=str(draft),
+                )
+                filtered_drafts.append(
+                    f"# {report_plan.report_outline[i].title}\n\n[Research failed: {draft!s}]"
+                )
+            elif draft is not None:
+                # Type narrowing: after Exception check, draft is str | None
+                assert isinstance(draft, str), "Expected str after Exception check"
+                filtered_drafts.append(draft)
+
+        self.logger.info(
+            "Parallel loops completed",
+            sections=len(filtered_drafts),
+            total_sections=len(report_plan.report_outline),
+        )
+
+        return filtered_drafts
+
+    def _get_next_node(self, node_id: str, context: GraphExecutionContext) -> list[str]:
+        """Get next node(s) from current node.
+
+        Args:
+            node_id: Current node ID
+            context: Execution context
+
+        Returns:
+            List of next node IDs
+        """
+        if not self._graph:
+            return []
+
+        # Get node result for condition evaluation
+        node_result = context.get_node_result(node_id)
+
+        # Get next nodes
+        next_nodes = self._graph.get_next_nodes(node_id, context=node_result)
+
+        # If this was a decision node, use its result
+        node = self._graph.get_node(node_id)
+        if isinstance(node, DecisionNode):
+            decision_result = node_result
+            if isinstance(decision_result, str):
+                return [decision_result]
+
+        # Return next node IDs
+        return [next_node_id for next_node_id, _ in next_nodes]
+
+    async def _detect_research_mode(self, query: str) -> Literal["iterative", "deep"]:
+        """
+        Detect research mode from query using input parser agent.
+
+        Uses input parser agent to analyze query and determine research mode.
+        Falls back to heuristic if parser fails.
+
+        Args:
+            query: The research query
+
+        Returns:
+            Detected research mode
+        """
+        try:
+            # Use input parser agent for intelligent mode detection
+            input_parser = create_input_parser_agent()
+            parsed_query = await input_parser.parse(query)
+            self.logger.info(
+                "Research mode detected by input parser",
+                mode=parsed_query.research_mode,
+                query=query[:100],
+            )
+            return parsed_query.research_mode
+        except Exception as e:
+            # Fallback to heuristic if parser fails
+            self.logger.warning(
+                "Input parser failed, using heuristic",
+                error=str(e),
+                query=query[:100],
+            )
+            query_lower = query.lower()
+            if any(
+                keyword in query_lower
+                for keyword in [
+                    "section",
+                    "sections",
+                    "report",
+                    "outline",
+                    "structure",
+                    "comprehensive",
+                    "analyze",
+                    "analysis",
+                ]
+            ):
+                return "deep"
+            return "iterative"
+
+
+def create_graph_orchestrator(
+    mode: Literal["iterative", "deep", "auto"] = "auto",
+    max_iterations: int = 5,
+    max_time_minutes: int = 10,
+    use_graph: bool = True,
+) -> GraphOrchestrator:
+    """
+    Factory function to create a graph orchestrator.
+
+    Args:
+        mode: Research mode
+        max_iterations: Maximum iterations per loop
+        max_time_minutes: Maximum time per loop
+        use_graph: Whether to use graph execution (True) or agent chains (False)
+
+    Returns:
+        Configured GraphOrchestrator instance
+    """
+    return GraphOrchestrator(
+        mode=mode,
+        max_iterations=max_iterations,
+        max_time_minutes=max_time_minutes,
+        use_graph=use_graph,
+    )
src/orchestrator/planner_agent.py
ADDED
@@ -0,0 +1,174 @@
+"""Planner agent for creating report plans with sections and background context.

+Converts the folder/planner_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.tools.crawl_adapter import crawl_website
+from src.tools.web_search_adapter import web_search
+from src.utils.exceptions import ConfigurationError, JudgeError
+from src.utils.models import ReportPlan, ReportPlanSection
+
+logger = structlog.get_logger()
+
+
+# System prompt for the planner agent
+SYSTEM_PROMPT = f"""
+You are a research manager, managing a team of research agents. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
+Given a research query, your job is to produce an initial outline of the report (section titles and key questions),
+as well as some background context. Each section will be assigned to a different researcher in your team who will then
+carry out research on the section.
+
+You will be given:
+- An initial research query
+
+Your task is to:
+1. Produce 1-2 paragraphs of initial background context (if needed) on the query by running web searches or crawling websites
+2. Produce an outline of the report that includes a list of section titles and the key question to be addressed in each section
+3. Provide a title for the report that will be used as the main heading
+
+Guidelines:
+- Each section should cover a single topic/question that is independent of other sections
+- The key question for each section should include both the NAME and DOMAIN NAME / WEBSITE (if available and applicable) if it is related to a company, product or similar
+- The background_context should not be more than 2 paragraphs
+- The background_context should be very specific to the query and include any information that is relevant for researchers across all sections of the report
+- The background_context should be drawn only from web search or crawl results rather than prior knowledge (i.e. it should only be included if you have called tools)
+- For example, if the query is about a company, the background context should include some basic information about what the company does
+- DO NOT do more than 2 tool calls
+
+Only output JSON. Follow the JSON schema for ReportPlan. Do not output anything else.
+"""
+
+
+class PlannerAgent:
+    """
+    Planner agent that creates report plans with sections and background context.
+
+    Uses Pydantic AI to generate structured ReportPlan output with optional
+    web search and crawl tool usage for background context.
+    """
+
+    def __init__(
+        self,
+        model: Any | None = None,
+        web_search_tool: Any | None = None,
+        crawl_tool: Any | None = None,
+    ) -> None:
+        """
+        Initialize the planner agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+            web_search_tool: Optional web search tool function. If None, uses default.
+            crawl_tool: Optional crawl tool function. If None, uses default.
+        """
+        self.model = model or get_model()
+        self.web_search_tool = web_search_tool or web_search
+        self.crawl_tool = crawl_tool or crawl_website
+        self.logger = logger
+
+        # Validate tools are callable
+        if not callable(self.web_search_tool):
+            raise ConfigurationError("web_search_tool must be callable")
+        if not callable(self.crawl_tool):
+            raise ConfigurationError("crawl_tool must be callable")
+
+        # Initialize Pydantic AI Agent
+        self.agent = Agent(
+            model=self.model,
+            output_type=ReportPlan,
+            system_prompt=SYSTEM_PROMPT,
+            tools=[self.web_search_tool, self.crawl_tool],
+            retries=3,
+        )
+
+    async def run(self, query: str) -> ReportPlan:
+        """
+        Run the planner agent to generate a report plan.
+
+        Args:
+            query: The user's research query
+
+        Returns:
+            ReportPlan with sections, background context, and report title
+
+        Raises:
+            JudgeError: If planning fails after retries
+            ConfigurationError: If agent configuration is invalid
+        """
+        self.logger.info("Starting report planning", query=query[:100])
+
+        user_message = f"QUERY: {query}"
+
+        try:
+            # Run the agent
+            result = await self.agent.run(user_message)
+            report_plan = result.output
+
+            # Validate report plan
+            if not report_plan.report_outline:
+                self.logger.warning("Report plan has no sections", query=query[:100])
+                raise JudgeError("Report plan must have at least one section")
+
+            if not report_plan.report_title:
+                self.logger.warning("Report plan has no title", query=query[:100])
+                raise JudgeError("Report plan must have a title")
+
+            self.logger.info(
+                "Report plan created",
+                sections=len(report_plan.report_outline),
+                has_background=bool(report_plan.background_context),
+            )
+
+            return report_plan
+
+        except Exception as e:
+            self.logger.error("Planning failed", error=str(e), query=query[:100])
+
+            # Fallback: return minimal report plan
+            if isinstance(e, JudgeError | ConfigurationError):
+                raise
+
+            # For other errors, return a minimal plan
+            return ReportPlan(
+                background_context="",
+                report_outline=[
+                    ReportPlanSection(
+                        title="Research Findings",
+                        key_question=query,
+                    )
+                ],
+                report_title=f"Research Report: {query[:50]}",
+            )
+
+
+def create_planner_agent(model: Any | None = None) -> PlannerAgent:
+    """
+    Factory function to create a planner agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured PlannerAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        # Get model from settings if not provided
+        if model is None:
+            model = get_model()
+
+        # Create and return planner agent
+        return PlannerAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create planner agent", error=str(e))
+        raise ConfigurationError(f"Failed to create planner agent: {e}") from e
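On its own, the planner is a single structured-output call; `ReportPlan` exposes the outline that the deep research flow later iterates over:

    import asyncio

    from src.orchestrator.planner_agent import create_planner_agent

    async def main() -> None:
        planner = create_planner_agent()
        plan = await planner.run("What existing drugs might help treat long COVID fatigue?")
        print(plan.report_title)
        for section in plan.report_outline:
            print(f"- {section.title}: {section.key_question}")

    asyncio.run(main())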
src/orchestrator/research_flow.py
ADDED
|
@@ -0,0 +1,999 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/orchestrator/research_flow.py
ADDED
@@ -0,0 +1,999 @@
"""Research flow implementations for iterative and deep research patterns.

Converts the folder/iterative_research.py and folder/deep_research.py
implementations to use Pydantic AI agents.
"""

import asyncio
import time
from typing import Any

import structlog

from src.agent_factory.agents import (
    create_graph_orchestrator,
    create_knowledge_gap_agent,
    create_long_writer_agent,
    create_planner_agent,
    create_proofreader_agent,
    create_thinking_agent,
    create_tool_selector_agent,
    create_writer_agent,
)
from src.agent_factory.judges import create_judge_handler
from src.middleware.budget_tracker import BudgetTracker
from src.middleware.state_machine import get_workflow_state, init_workflow_state
from src.middleware.workflow_manager import WorkflowManager
from src.services.llamaindex_rag import LlamaIndexRAGService, get_rag_service
from src.tools.tool_executor import execute_tool_tasks
from src.utils.exceptions import ConfigurationError
from src.utils.models import (
    AgentSelectionPlan,
    AgentTask,
    Citation,
    Conversation,
    Evidence,
    JudgeAssessment,
    KnowledgeGapOutput,
    ReportDraft,
    ReportDraftSection,
    ReportPlan,
    SourceName,
    ToolAgentOutput,
)

logger = structlog.get_logger()


class IterativeResearchFlow:
    """
    Iterative research flow that runs a single research loop.

    Pattern: Generate observations → Evaluate gaps → Select tools → Execute → Repeat
    until research is complete or constraints are met.
    """

    def __init__(
        self,
        max_iterations: int = 5,
        max_time_minutes: int = 10,
        verbose: bool = True,
        use_graph: bool = False,
        judge_handler: Any | None = None,
    ) -> None:
        """
        Initialize iterative research flow.

        Args:
            max_iterations: Maximum number of iterations
            max_time_minutes: Maximum time in minutes
            verbose: Whether to log progress
            use_graph: Whether to use graph-based execution (True) or agent chains (False)
            judge_handler: Optional shared judge handler; a new one is created if omitted
        """
        self.max_iterations = max_iterations
        self.max_time_minutes = max_time_minutes
        self.verbose = verbose
        self.use_graph = use_graph
        self.logger = logger

        # Initialize agents (only needed for agent chain execution)
        if not use_graph:
            self.knowledge_gap_agent = create_knowledge_gap_agent()
            self.tool_selector_agent = create_tool_selector_agent()
            self.thinking_agent = create_thinking_agent()
            self.writer_agent = create_writer_agent()
            # Initialize judge handler (use provided or create new)
            self.judge_handler = judge_handler or create_judge_handler()

        # Initialize state (only needed for agent chain execution)
        if not use_graph:
            self.conversation = Conversation()
            self.iteration = 0
            self.start_time: float | None = None
            self.should_continue = True

        # Initialize budget tracker
        self.budget_tracker = BudgetTracker()
        self.loop_id = "iterative_flow"
        self.budget_tracker.create_budget(
            loop_id=self.loop_id,
            tokens_limit=100000,
            time_limit_seconds=max_time_minutes * 60,
            iterations_limit=max_iterations,
        )
        self.budget_tracker.start_timer(self.loop_id)

        # Initialize RAG service (lazy, may be None if unavailable)
        self._rag_service: LlamaIndexRAGService | None = None

        # Graph orchestrator (lazy initialization)
        self._graph_orchestrator: Any = None

    async def run(
        self,
        query: str,
        background_context: str = "",
        output_length: str = "",
        output_instructions: str = "",
    ) -> str:
        """
        Run the iterative research flow.

        Args:
            query: The research query
            background_context: Optional background context
            output_length: Optional description of desired output length
            output_instructions: Optional additional instructions

        Returns:
            Final report string
        """
        if self.use_graph:
            return await self._run_with_graph(
                query, background_context, output_length, output_instructions
            )
        else:
            return await self._run_with_chains(
                query, background_context, output_length, output_instructions
            )

    async def _run_with_chains(
        self,
        query: str,
        background_context: str = "",
        output_length: str = "",
        output_instructions: str = "",
    ) -> str:
        """
        Run the iterative research flow using agent chains.

        Args:
            query: The research query
            background_context: Optional background context
            output_length: Optional description of desired output length
            output_instructions: Optional additional instructions

        Returns:
            Final report string
        """
        self.start_time = time.time()
        self.logger.info("Starting iterative research (agent chains)", query=query[:100])

        # Initialize conversation with first iteration
        self.conversation.add_iteration()

        # Main research loop
        while self.should_continue and self._check_constraints():
            self.iteration += 1
            self.logger.info("Starting iteration", iteration=self.iteration)

            # Add new iteration to conversation
            self.conversation.add_iteration()

            # 1. Generate observations
            await self._generate_observations(query, background_context)

            # 2. Evaluate gaps
            evaluation = await self._evaluate_gaps(query, background_context)

            # 3. Check if research is complete per the knowledge gap agent;
            # a full judge assessment follows after tool execution
            if evaluation.research_complete:
                self.should_continue = False
                self.logger.info("Research marked as complete by knowledge gap agent")
                break

            # 4. Select tools for next gap
            next_gap = evaluation.outstanding_gaps[0] if evaluation.outstanding_gaps else query
            selection_plan = await self._select_agents(next_gap, query, background_context)

            # 5. Execute tools
            await self._execute_tools(selection_plan.tasks)

            # 6. Assess evidence sufficiency with judge
            judge_assessment = await self._assess_with_judge(query)

            # Check if judge says evidence is sufficient
            if judge_assessment.sufficient:
                self.should_continue = False
                self.logger.info(
                    "Research marked as complete by judge",
                    confidence=judge_assessment.confidence,
                    reasoning=judge_assessment.reasoning[:100],
                )
                break

            # Update budget tracker
            self.budget_tracker.increment_iteration(self.loop_id)
            self.budget_tracker.update_timer(self.loop_id)

        # Create final report
        report = await self._create_final_report(query, output_length, output_instructions)

        elapsed = time.time() - (self.start_time or time.time())
        self.logger.info(
            "Iterative research completed",
            iterations=self.iteration,
            elapsed_minutes=elapsed / 60,
        )

        return report

    async def _run_with_graph(
        self,
        query: str,
        background_context: str = "",
        output_length: str = "",
        output_instructions: str = "",
    ) -> str:
        """
        Run the iterative research flow using graph execution.

        Args:
            query: The research query
            background_context: Optional background context (currently ignored in graph execution)
            output_length: Optional description of desired output length (currently ignored in graph execution)
            output_instructions: Optional additional instructions (currently ignored in graph execution)

        Returns:
            Final report string
        """
        self.logger.info("Starting iterative research (graph execution)", query=query[:100])

        # Create graph orchestrator (lazy initialization)
        if self._graph_orchestrator is None:
            self._graph_orchestrator = create_graph_orchestrator(
                mode="iterative",
                max_iterations=self.max_iterations,
                max_time_minutes=self.max_time_minutes,
                use_graph=True,
            )

        # Run orchestrator and collect events
        final_report = ""
        async for event in self._graph_orchestrator.run(query):
            if event.type == "complete":
                final_report = event.message
                break
            elif event.type == "error":
                self.logger.error("Graph execution error", error=event.message)
                raise RuntimeError(f"Graph execution failed: {event.message}")

        if not final_report:
            self.logger.warning("No complete event received from graph orchestrator")
            final_report = "Research completed but no report was generated."

        self.logger.info("Iterative research completed (graph execution)")

        return final_report

    def _check_constraints(self) -> bool:
        """Check if we've exceeded constraints."""
        if self.iteration >= self.max_iterations:
            self.logger.info("Max iterations reached", max=self.max_iterations)
            return False

        if self.start_time:
            elapsed_minutes = (time.time() - self.start_time) / 60
            if elapsed_minutes >= self.max_time_minutes:
                self.logger.info("Max time reached", max=self.max_time_minutes)
                return False

        # Check budget tracker
        self.budget_tracker.update_timer(self.loop_id)
        exceeded, reason = self.budget_tracker.check_budget(self.loop_id)
        if exceeded:
            self.logger.info("Budget exceeded", reason=reason)
            return False

        return True

    async def _generate_observations(self, query: str, background_context: str = "") -> str:
        """Generate observations from current research state."""
        # Build input prompt for token estimation
        conversation_history = self.conversation.compile_conversation_history()
        # Build background context section separately to avoid backslash in f-string
        background_section = (
            f"BACKGROUND CONTEXT:\n{background_context}\n\n" if background_context else ""
        )
        input_prompt = f"""
You are starting iteration {self.iteration} of your research process.

ORIGINAL QUERY:
{query}

{background_section}HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
{conversation_history or "No previous actions, findings or thoughts available."}
"""

        observations = await self.thinking_agent.generate_observations(
            query=query,
            background_context=background_context,
            conversation_history=conversation_history,
            iteration=self.iteration,
        )

        # Track tokens for this iteration
        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, observations)
        self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
        self.logger.debug(
            "Tokens tracked for thinking agent",
            iteration=self.iteration,
            tokens=estimated_tokens,
        )

        self.conversation.set_latest_thought(observations)
        return observations

    async def _evaluate_gaps(self, query: str, background_context: str = "") -> KnowledgeGapOutput:
        """Evaluate knowledge gaps in current research."""
        if self.start_time:
            elapsed_minutes = (time.time() - self.start_time) / 60
        else:
            elapsed_minutes = 0.0

        # Build input prompt for token estimation
        conversation_history = self.conversation.compile_conversation_history()
        background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
        input_prompt = f"""
Current Iteration Number: {self.iteration}
Time Elapsed: {elapsed_minutes:.2f} minutes of maximum {self.max_time_minutes} minutes

ORIGINAL QUERY:
{query}

{background}

HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
{conversation_history or "No previous actions, findings or thoughts available."}
"""

        evaluation = await self.knowledge_gap_agent.evaluate(
            query=query,
            background_context=background_context,
            conversation_history=conversation_history,
            iteration=self.iteration,
            time_elapsed_minutes=elapsed_minutes,
            max_time_minutes=self.max_time_minutes,
        )

        # Track tokens for this iteration
        evaluation_text = f"research_complete={evaluation.research_complete}, gaps={len(evaluation.outstanding_gaps)}"
        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
            input_prompt, evaluation_text
        )
        self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
        self.logger.debug(
            "Tokens tracked for knowledge gap agent",
            iteration=self.iteration,
            tokens=estimated_tokens,
        )

        if not evaluation.research_complete and evaluation.outstanding_gaps:
            self.conversation.set_latest_gap(evaluation.outstanding_gaps[0])

        return evaluation

    async def _assess_with_judge(self, query: str) -> JudgeAssessment:
        """Assess evidence sufficiency using JudgeHandler.

        Args:
            query: The research query

        Returns:
            JudgeAssessment with sufficiency evaluation
        """
        state = get_workflow_state()
        evidence = state.evidence  # Get all collected evidence

        self.logger.info(
            "Assessing evidence with judge",
            query=query[:100],
            evidence_count=len(evidence),
        )

        assessment = await self.judge_handler.assess(query, evidence)

        # Track tokens for the judge call, estimated from query + sampled evidence + assessment
        evidence_text = "\n".join([e.content[:500] for e in evidence[:10]])  # Sample
        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
            query + evidence_text, str(assessment.reasoning)
        )
        self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)

        self.logger.info(
            "Judge assessment complete",
            sufficient=assessment.sufficient,
            confidence=assessment.confidence,
            recommendation=assessment.recommendation,
        )

        return assessment

    async def _select_agents(
        self, gap: str, query: str, background_context: str = ""
    ) -> AgentSelectionPlan:
        """Select tools to address knowledge gap."""
        # Build input prompt for token estimation
        conversation_history = self.conversation.compile_conversation_history()
        background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
        input_prompt = f"""
ORIGINAL QUERY:
{query}

KNOWLEDGE GAP TO ADDRESS:
{gap}

{background}

HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
{conversation_history or "No previous actions, findings or thoughts available."}
"""

        selection_plan = await self.tool_selector_agent.select_tools(
            gap=gap,
            query=query,
            background_context=background_context,
            conversation_history=conversation_history,
        )

        # Track tokens for this iteration
        selection_text = f"tasks={len(selection_plan.tasks)}, agents={[task.agent for task in selection_plan.tasks]}"
        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
            input_prompt, selection_text
        )
        self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
        self.logger.debug(
            "Tokens tracked for tool selector agent",
            iteration=self.iteration,
            tokens=estimated_tokens,
        )

        # Store tool calls in conversation
        tool_calls = [
            f"[Agent] {task.agent} [Query] {task.query} [Entity] {task.entity_website or 'null'}"
            for task in selection_plan.tasks
        ]
        self.conversation.set_latest_tool_calls(tool_calls)

        return selection_plan

    def _get_rag_service(self) -> LlamaIndexRAGService | None:
        """
        Get or create RAG service instance.

        Returns:
            RAG service instance, or None if unavailable
        """
        if self._rag_service is None:
            try:
                self._rag_service = get_rag_service()
                self.logger.info("RAG service initialized for research flow")
            except (ConfigurationError, ImportError) as e:
                self.logger.warning(
                    "RAG service unavailable", error=str(e), hint="OPENAI_API_KEY required"
                )
                return None
        return self._rag_service

    async def _execute_tools(self, tasks: list[AgentTask]) -> dict[str, ToolAgentOutput]:
        """Execute selected tools concurrently."""
        try:
            results = await execute_tool_tasks(tasks)
        except Exception as e:
            # Handle tool execution errors gracefully
            self.logger.error(
                "Tool execution failed",
                error=str(e),
                task_count=len(tasks),
                exc_info=True,
            )
            # Return empty results so the research flow can continue;
            # a report can still be generated from previous iterations
            results = {}

        # Store findings in conversation (only if we have results)
        evidence_list: list[Evidence] = []
        if results:
            findings = [result.output for result in results.values()]
            self.conversation.set_latest_findings(findings)

            # Convert tool outputs to Evidence objects and store in workflow state
            evidence_list = self._convert_tool_outputs_to_evidence(results)

            if evidence_list:
                state = get_workflow_state()
                added_count = state.add_evidence(evidence_list)
                self.logger.info(
                    "Evidence added to workflow state",
                    count=added_count,
                    total_evidence=len(state.evidence),
                )

                # Ingest evidence into RAG if available (Phase 6 requirement)
                rag_service = self._get_rag_service()
                if rag_service is not None:
                    try:
                        # ingest_evidence is synchronous, run in executor to avoid blocking
                        loop = asyncio.get_event_loop()
                        await loop.run_in_executor(None, rag_service.ingest_evidence, evidence_list)
                        self.logger.info(
                            "Evidence ingested into RAG",
                            count=len(evidence_list),
                        )
                    except Exception as e:
                        # Don't fail the research loop if RAG ingestion fails
                        self.logger.warning(
                            "Failed to ingest evidence into RAG",
                            error=str(e),
                            count=len(evidence_list),
                        )

        return results

    def _convert_tool_outputs_to_evidence(
        self, tool_results: dict[str, ToolAgentOutput]
    ) -> list[Evidence]:
        """Convert ToolAgentOutput to Evidence objects.

        Args:
            tool_results: Dictionary of tool execution results

        Returns:
            List of Evidence objects
        """
        evidence_list = []
        for key, result in tool_results.items():
            # Extract URLs from sources
            if result.sources:
                # Create one Evidence object per source URL
                for url in result.sources:
                    # Determine source type from URL or tool name;
                    # default to "web" for unknown web sources
                    source_type: SourceName = "web"
                    if "pubmed" in url.lower() or "ncbi" in url.lower():
                        source_type = "pubmed"
                    elif "clinicaltrials" in url.lower():
                        source_type = "clinicaltrials"
                    elif "europepmc" in url.lower():
                        source_type = "europepmc"
                    elif "biorxiv" in url.lower():
                        source_type = "biorxiv"
                    elif "arxiv" in url.lower() or "preprint" in url.lower():
                        source_type = "preprint"
                    # Note: "web" is now a valid SourceName for general web sources

                    citation = Citation(
                        title=f"Tool Result: {key}",
                        url=url,
                        source=source_type,
                        date="n.d.",
                        authors=[],
                    )
                    # Truncate content to a reasonable length for the judge (1500 chars)
                    content = result.output[:1500]
                    if len(result.output) > 1500:
                        content += "... [truncated]"

                    evidence = Evidence(
                        content=content,
                        citation=citation,
                        relevance=0.5,  # Default relevance
                    )
                    evidence_list.append(evidence)
            else:
                # No URLs: create a single Evidence object from the tool output,
                # using a placeholder URL based on the tool name.
                # Determine source type from tool name
                tool_source_type: SourceName = "web"  # Default for unknown sources
                if "RAG" in key:
                    tool_source_type = "rag"
                elif "WebSearch" in key or "SiteCrawler" in key:
                    tool_source_type = "web"

                citation = Citation(
                    title=f"Tool Result: {key}",
                    url=f"tool://{key}",
                    source=tool_source_type,
                    date="n.d.",
                    authors=[],
                )
                content = result.output[:1500]
                if len(result.output) > 1500:
                    content += "... [truncated]"

                evidence = Evidence(
                    content=content,
                    citation=citation,
                    relevance=0.5,
                )
                evidence_list.append(evidence)

        return evidence_list

    async def _create_final_report(
        self, query: str, length: str = "", instructions: str = ""
    ) -> str:
        """Create final report from all findings."""
        all_findings = "\n\n".join(self.conversation.get_all_findings())
        if not all_findings:
            all_findings = "No findings available yet."

        # Build input prompt for token estimation
        length_str = f"* The full response should be approximately {length}.\n" if length else ""
        instructions_str = f"* {instructions}" if instructions else ""
        guidelines_str = (
            ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n")
            if length or instructions
            else ""
        )
        input_prompt = f"""
Provide a response based on the query and findings below with as much detail as possible. {guidelines_str}

QUERY: {query}

FINDINGS:
{all_findings}
"""

        report = await self.writer_agent.write_report(
            query=query,
            findings=all_findings,
            output_length=length,
            output_instructions=instructions,
        )

        # Track tokens for the final report (totaled, not per iteration)
        estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, report)
        self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
        self.logger.debug(
            "Tokens tracked for writer agent (final report)",
            tokens=estimated_tokens,
        )

        # Note: Citation validation for markdown reports would require Evidence objects.
        # Currently, findings are strings, not Evidence objects. For full validation,
        # consider using the ResearchReport format or passing Evidence objects separately.
        # See src/utils/citation_validator.py for markdown citation validation utilities.

        return report


class DeepResearchFlow:
    """
    Deep research flow that runs parallel iterative loops per section.

    Pattern: Plan → Parallel Iterative Loops (one per section) → Synthesis
    """

    def __init__(
        self,
        max_iterations: int = 5,
        max_time_minutes: int = 10,
        verbose: bool = True,
        use_long_writer: bool = True,
        use_graph: bool = False,
    ) -> None:
        """
        Initialize deep research flow.

        Args:
            max_iterations: Maximum iterations per section
            max_time_minutes: Maximum time per section
            verbose: Whether to log progress
            use_long_writer: Whether to use long writer (True) or proofreader (False)
            use_graph: Whether to use graph-based execution (True) or agent chains (False)
        """
        self.max_iterations = max_iterations
        self.max_time_minutes = max_time_minutes
        self.verbose = verbose
        self.use_long_writer = use_long_writer
        self.use_graph = use_graph
        self.logger = logger

        # Initialize agents (only needed for agent chain execution)
        if not use_graph:
            self.planner_agent = create_planner_agent()
            self.long_writer_agent = create_long_writer_agent()
            self.proofreader_agent = create_proofreader_agent()
            # Initialize judge handler for section loop completion
            self.judge_handler = create_judge_handler()
            # Initialize budget tracker for token tracking
            self.budget_tracker = BudgetTracker()
            self.loop_id = "deep_research_flow"
            self.budget_tracker.create_budget(
                loop_id=self.loop_id,
                tokens_limit=200000,  # Higher limit for deep research
                time_limit_seconds=max_time_minutes * 60 * 2,  # Allow more time for parallel sections
                iterations_limit=max_iterations * 10,  # Allow for multiple sections
            )
            self.budget_tracker.start_timer(self.loop_id)

        # Graph orchestrator (lazy initialization)
        self._graph_orchestrator: Any = None

    async def run(self, query: str) -> str:
        """
        Run the deep research flow.

        Args:
            query: The research query

        Returns:
            Final report string
        """
        if self.use_graph:
            return await self._run_with_graph(query)
        else:
            return await self._run_with_chains(query)

    async def _run_with_chains(self, query: str) -> str:
        """
        Run the deep research flow using agent chains.

        Args:
            query: The research query

        Returns:
            Final report string
        """
        self.logger.info("Starting deep research (agent chains)", query=query[:100])

        # Initialize workflow state for deep research
        try:
            from src.services.embeddings import get_embedding_service

            embedding_service = get_embedding_service()
        except Exception:
            # If the embedding service is unavailable (including import errors),
            # initialize the workflow state without it
            embedding_service = None
            self.logger.debug("Embedding service unavailable, initializing state without it")

        init_workflow_state(embedding_service=embedding_service)
        self.logger.debug("Workflow state initialized for deep research")

        # 1. Build report plan
        report_plan = await self._build_report_plan(query)
        self.logger.info(
            "Report plan created",
            sections=len(report_plan.report_outline),
            title=report_plan.report_title,
        )

        # 2. Run parallel research loops with state synchronization
        section_drafts = await self._run_research_loops(report_plan)

        # Verify state synchronization - log evidence count
        state = get_workflow_state()
        self.logger.info(
            "State synchronization complete",
            total_evidence=len(state.evidence),
            sections_completed=len(section_drafts),
        )

        # 3. Create final report
        final_report = await self._create_final_report(query, report_plan, section_drafts)

        self.logger.info(
            "Deep research completed",
            sections=len(section_drafts),
            final_report_length=len(final_report),
        )

        return final_report

    async def _run_with_graph(self, query: str) -> str:
        """
        Run the deep research flow using graph execution.

        Args:
            query: The research query

        Returns:
            Final report string
        """
        self.logger.info("Starting deep research (graph execution)", query=query[:100])

        # Create graph orchestrator (lazy initialization)
        if self._graph_orchestrator is None:
            self._graph_orchestrator = create_graph_orchestrator(
                mode="deep",
                max_iterations=self.max_iterations,
                max_time_minutes=self.max_time_minutes,
                use_graph=True,
            )

        # Run orchestrator and collect events
        final_report = ""
        async for event in self._graph_orchestrator.run(query):
            if event.type == "complete":
                final_report = event.message
                break
            elif event.type == "error":
                self.logger.error("Graph execution error", error=event.message)
                raise RuntimeError(f"Graph execution failed: {event.message}")

        if not final_report:
            self.logger.warning("No complete event received from graph orchestrator")
            final_report = "Research completed but no report was generated."

        self.logger.info("Deep research completed (graph execution)")

        return final_report

    async def _build_report_plan(self, query: str) -> ReportPlan:
        """Build the initial report plan."""
        self.logger.info("Building report plan")

        # Build input prompt for token estimation
        input_prompt = f"QUERY: {query}"

        report_plan = await self.planner_agent.run(query)

        # Track tokens for planner agent
        if not self.use_graph and hasattr(self, "budget_tracker"):
            plan_text = (
                f"title={report_plan.report_title}, sections={len(report_plan.report_outline)}"
            )
            estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, plan_text)
            self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
            self.logger.debug(
                "Tokens tracked for planner agent",
                tokens=estimated_tokens,
            )

        self.logger.info(
            "Report plan created",
            sections=len(report_plan.report_outline),
            has_background=bool(report_plan.background_context),
        )

        return report_plan

    async def _run_research_loops(self, report_plan: ReportPlan) -> list[str]:
        """Run parallel iterative research loops for each section."""
        self.logger.info("Running research loops", sections=len(report_plan.report_outline))

        # Create workflow manager for parallel execution
        workflow_manager = WorkflowManager()

        # Create loop configurations
        loop_configs = [
            {
                "loop_id": f"section_{i}",
                "query": section.key_question,
                "section_title": section.title,
                "background_context": report_plan.background_context,
            }
            for i, section in enumerate(report_plan.report_outline)
        ]

        async def run_research_for_section(config: dict[str, Any]) -> str:
            """Run iterative research for a single section."""
            loop_id = config.get("loop_id", "unknown")
            query = config.get("query", "")
            background_context = config.get("background_context", "")

            try:
                # Update loop status
                await workflow_manager.update_loop_status(loop_id, "running")

                # Create iterative research flow
                flow = IterativeResearchFlow(
                    max_iterations=self.max_iterations,
                    max_time_minutes=self.max_time_minutes,
                    verbose=self.verbose,
                    use_graph=self.use_graph,
                    judge_handler=self.judge_handler if not self.use_graph else None,
                )

                # Run research
                result = await flow.run(
                    query=query,
                    background_context=background_context,
                )

                # Sync evidence from flow to loop
                state = get_workflow_state()
                if state.evidence:
                    await workflow_manager.add_loop_evidence(loop_id, state.evidence)

                # Update loop status
                await workflow_manager.update_loop_status(loop_id, "completed")

                return result

            except Exception as e:
                error_msg = str(e)
                await workflow_manager.update_loop_status(loop_id, "failed", error=error_msg)
                self.logger.error(
                    "Section research failed",
                    loop_id=loop_id,
                    error=error_msg,
                )
                raise

        # Run all sections in parallel using the workflow manager
        section_drafts = await workflow_manager.run_loops_parallel(
            loop_configs=loop_configs,
            loop_func=run_research_for_section,
            judge_handler=self.judge_handler if not self.use_graph else None,
            budget_tracker=self.budget_tracker if not self.use_graph else None,
        )

        # Sync evidence from all loops to global state
        for config in loop_configs:
            loop_id = config.get("loop_id")
            if loop_id:
                await workflow_manager.sync_loop_evidence_to_state(loop_id)

        # Filter out None results (failed loops)
        section_drafts = [draft for draft in section_drafts if draft is not None]

        self.logger.info(
            "Research loops completed",
            drafts=len(section_drafts),
            total_sections=len(report_plan.report_outline),
        )

        return section_drafts

    async def _create_final_report(
        self, query: str, report_plan: ReportPlan, section_drafts: list[str]
    ) -> str:
        """Create final report from section drafts."""
        self.logger.info("Creating final report")

        # Create ReportDraft from section drafts
        report_draft = ReportDraft(
            sections=[
                ReportDraftSection(
                    section_title=section.title,
                    section_content=draft,
                )
                for section, draft in zip(report_plan.report_outline, section_drafts, strict=False)
            ]
        )

        # Build input prompt for token estimation
        draft_text = "\n".join(
            [s.section_content[:500] for s in report_draft.sections[:5]]
        )  # Sample
        input_prompt = f"QUERY: {query}\nTITLE: {report_plan.report_title}\nDRAFT: {draft_text}"

        if self.use_long_writer:
            # Use long writer agent
            final_report = await self.long_writer_agent.write_report(
                original_query=query,
                report_title=report_plan.report_title,
                report_draft=report_draft,
            )
        else:
            # Use proofreader agent
            final_report = await self.proofreader_agent.proofread(
                query=query,
                report_draft=report_draft,
            )

        # Track tokens for final report synthesis
        if not self.use_graph and hasattr(self, "budget_tracker"):
            estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
                input_prompt, final_report
            )
            self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
            self.logger.debug(
                "Tokens tracked for final report synthesis",
                tokens=estimated_tokens,
                agent="long_writer" if self.use_long_writer else "proofreader",
            )

        self.logger.info("Final report created", length=len(final_report))

        return final_report
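A minimal usage sketch for the two flows above, assuming an asyncio entrypoint and that the configured model/API keys are available in the environment; the queries shown are illustrative only.

import asyncio

from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow


async def main() -> None:
    # Single research loop driven by agent chains (use_graph=False is the default)
    iterative = IterativeResearchFlow(max_iterations=3, max_time_minutes=5)
    print(await iterative.run("What existing drugs might help treat long COVID fatigue?"))

    # Plan -> parallel per-section loops -> synthesis via the long writer
    deep = DeepResearchFlow(max_iterations=3, max_time_minutes=5, use_long_writer=True)
    print(await deep.run("Repurposing candidates for long COVID fatigue"))


if __name__ == "__main__":
    asyncio.run(main())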
src/orchestrator_factory.py
CHANGED
@@ -2,7 +2,7 @@
 
 from typing import Any, Literal
 
-from src.orchestrator import JudgeHandlerProtocol, Orchestrator, SearchHandlerProtocol
+from src.legacy_orchestrator import JudgeHandlerProtocol, Orchestrator, SearchHandlerProtocol
 from src.utils.models import OrchestratorConfig
 
 
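Since src/orchestrator.py was renamed to src/legacy_orchestrator.py in this commit, any other code importing the legacy classes needs the same one-line change sketched here (only the module path changes; the class names are untouched):

# before the rename: from src.orchestrator import Orchestrator
from src.legacy_orchestrator import JudgeHandlerProtocol, Orchestrator, SearchHandlerProtocol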
src/tools/__init__.py
CHANGED
@@ -2,7 +2,14 @@
 
 from src.tools.base import SearchTool
 from src.tools.pubmed import PubMedTool
+from src.tools.rag_tool import RAGTool, create_rag_tool
 from src.tools.search_handler import SearchHandler
 
 # Re-export
-__all__ = ["PubMedTool", "SearchHandler", "SearchTool"]
+__all__ = [
+    "PubMedTool",
+    "SearchHandler",
+    "SearchTool",
+    "RAGTool",
+    "create_rag_tool",
+]
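With these re-exports in place, the RAG tool is importable from the package root alongside the existing tools:

from src.tools import PubMedTool, RAGTool, SearchHandler, create_rag_tool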
src/tools/crawl_adapter.py
ADDED
@@ -0,0 +1,58 @@
"""Website crawl tool adapter for Pydantic AI agents.

Adapts the folder/tools/crawl_website.py implementation to work with Pydantic AI.
"""

import structlog

logger = structlog.get_logger()


async def crawl_website(starting_url: str) -> str:
    """
    Crawl a website starting from the given URL and return formatted results.

    Use this tool to crawl a website for information relevant to the query.
    Provide a starting URL as input.

    Args:
        starting_url: The starting URL to crawl (e.g., "https://example.com")

    Returns:
        Formatted string with crawled content including titles, descriptions, and URLs
    """
    try:
        # Lazy import to avoid requiring folder/ dependencies at import time
        from folder.tools.crawl_website import crawl_website as crawl_tool

        # Call the tool function; it returns List[ScrapeResult] or str
        results = await crawl_tool(starting_url)

        if isinstance(results, str):
            # Error message returned
            logger.warning("Crawl returned error", error=results)
            return results

        if not results:
            return f"No content found when crawling: {starting_url}"

        # Format results for agent consumption
        formatted = [f"Found {len(results)} pages from {starting_url}:\n"]
        for i, result in enumerate(results[:10], 1):  # Limit to 10 pages
            formatted.append(f"{i}. **{result.title or 'Untitled'}**")
            if result.description:
                formatted.append(f"   {result.description[:200]}...")
            formatted.append(f"   URL: {result.url}")
            if result.text:
                formatted.append(f"   Content: {result.text[:500]}...")
            formatted.append("")

        return "\n".join(formatted)

    except ImportError as e:
        logger.error("Crawl tool not available", error=str(e))
        return f"Crawl tool not available: {e!s}"
    except Exception as e:
        logger.error("Crawl failed", error=str(e), url=starting_url)
        return f"Error crawling website: {e!s}"
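A sketch of calling the adapter directly (it needs the folder/ package at runtime; the URL is illustrative). Note that crawl_website always returns a string, either formatted results or an error message, and never raises:

import asyncio

from src.tools.crawl_adapter import crawl_website


async def demo() -> None:
    text = await crawl_website("https://example.com")
    print(text)


asyncio.run(demo())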
src/tools/rag_tool.py
ADDED
@@ -0,0 +1,183 @@
"""RAG tool for semantic search within collected evidence.

Implements SearchTool protocol to enable RAG as a search option in the research workflow.
"""

from typing import TYPE_CHECKING, Any

import structlog

from src.utils.exceptions import ConfigurationError
from src.utils.models import Citation, Evidence, SourceName

if TYPE_CHECKING:
    from src.services.llamaindex_rag import LlamaIndexRAGService

logger = structlog.get_logger()


class RAGTool:
    """Search tool that uses LlamaIndex RAG for semantic search within collected evidence.

    Wraps LlamaIndexRAGService to implement the SearchTool protocol.
    Returns Evidence objects from RAG retrieval results.
    """

    def __init__(self, rag_service: "LlamaIndexRAGService | None" = None) -> None:
        """
        Initialize RAG tool.

        Args:
            rag_service: Optional RAG service instance. If None, will be lazy-initialized.
        """
        self._rag_service = rag_service
        self.logger = logger

    @property
    def name(self) -> str:
        """Return the tool name."""
        return "rag"

    def _get_rag_service(self) -> "LlamaIndexRAGService":
        """
        Get or create RAG service instance.

        Returns:
            LlamaIndexRAGService instance

        Raises:
            ConfigurationError: If RAG service cannot be initialized
        """
        if self._rag_service is None:
            try:
                from src.services.llamaindex_rag import get_rag_service

                self._rag_service = get_rag_service()
                self.logger.info("RAG service initialized")
            except (ConfigurationError, ImportError) as e:
                self.logger.error("Failed to initialize RAG service", error=str(e))
                raise ConfigurationError("RAG service unavailable. OPENAI_API_KEY required.") from e

        return self._rag_service

    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
        """
        Search RAG system and return evidence.

        Args:
            query: The search query string
            max_results: Maximum number of results to return

        Returns:
            List of Evidence objects from RAG retrieval

        Note:
            Returns empty list on error (does not raise exceptions).
        """
        try:
            rag_service = self._get_rag_service()
        except ConfigurationError:
            self.logger.warning("RAG service unavailable, returning empty results")
            return []

        try:
            # Retrieve documents from RAG
            retrieved_docs = rag_service.retrieve(query, top_k=max_results)

            if not retrieved_docs:
                self.logger.info("No RAG results found", query=query[:50])
                return []

            # Convert retrieved documents to Evidence objects
            evidence_list: list[Evidence] = []
            for doc in retrieved_docs:
                try:
                    evidence = self._doc_to_evidence(doc)
                    evidence_list.append(evidence)
                except Exception as e:
                    self.logger.warning(
                        "Failed to convert document to evidence",
                        error=str(e),
                        doc_text=doc.get("text", "")[:50],
                    )
                    continue

            self.logger.info(
                "RAG search completed",
                query=query[:50],
                results=len(evidence_list),
            )
            return evidence_list

        except Exception as e:
            self.logger.error("RAG search failed", error=str(e), query=query[:50])
            # Return empty list on error (graceful degradation)
            return []

    def _doc_to_evidence(self, doc: dict[str, Any]) -> Evidence:
        """
        Convert RAG document to Evidence object.

        Args:
            doc: Document dict with keys: text, score, metadata

        Returns:
            Evidence object

        Raises:
            ValueError: If document is missing required fields
        """
        text = doc.get("text", "")
        if not text:
            raise ValueError("Document missing text content")

        metadata = doc.get("metadata", {})
        score = doc.get("score", 0.0)

        # Extract citation information from metadata
        source: SourceName = "rag"  # RAG is the source
        title = metadata.get("title", "Untitled")
        url = metadata.get("url", "")
        date = metadata.get("date", "Unknown")
        authors_str = metadata.get("authors", "")
        authors = [a.strip() for a in authors_str.split(",") if a.strip()] if authors_str else []

        # Create citation
        citation = Citation(
            source=source,
            title=title[:500],  # Enforce max length
            url=url,
            date=date,
            authors=authors,
        )

        # Create evidence with relevance score (normalize score to 0-1 if needed)
        relevance = min(max(float(score), 0.0), 1.0) if score else 0.0

        return Evidence(
            content=text,
            citation=citation,
            relevance=relevance,
        )


def create_rag_tool(
    rag_service: "LlamaIndexRAGService | None" = None,
) -> RAGTool:
    """
    Factory function to create a RAG tool.

    Args:
        rag_service: Optional RAG service instance. If None, will be lazy-initialized.

    Returns:
        Configured RAGTool instance

    Raises:
        ConfigurationError: If RAG service cannot be initialized and rag_service is None
    """
    try:
        return RAGTool(rag_service=rag_service)
    except Exception as e:
        logger.error("Failed to create RAG tool", error=str(e))
        raise ConfigurationError(f"Failed to create RAG tool: {e}") from e
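A sketch of using the RAG tool through the SearchTool protocol. create_rag_tool() itself is cheap (the service is lazily initialized on first use), and search() degrades to an empty list when no OPENAI_API_KEY is configured:

import asyncio

from src.tools.rag_tool import create_rag_tool


async def demo() -> None:
    tool = create_rag_tool()
    evidence = await tool.search("long COVID fatigue mechanisms", max_results=5)
    for ev in evidence:
        print(f"{ev.relevance:.2f}  {ev.citation.title}")


asyncio.run(demo())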
src/tools/search_handler.py
CHANGED
@@ -1,30 +1,74 @@
 """Search handler - orchestrates multiple search tools."""
 
 import asyncio
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 import structlog
 
 from src.tools.base import SearchTool
-from src.utils.exceptions import SearchError
+from src.tools.rag_tool import create_rag_tool
+from src.utils.exceptions import ConfigurationError, SearchError
 from src.utils.models import Evidence, SearchResult, SourceName
 
+if TYPE_CHECKING:
+    from src.services.llamaindex_rag import LlamaIndexRAGService
+
 logger = structlog.get_logger()
 
 
 class SearchHandler:
     """Orchestrates parallel searches across multiple tools."""
 
-    def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
+    def __init__(
+        self,
+        tools: list[SearchTool],
+        timeout: float = 30.0,
+        include_rag: bool = False,
+        auto_ingest_to_rag: bool = True,
+    ) -> None:
         """
         Initialize the search handler.
 
         Args:
             tools: List of search tools to use
             timeout: Timeout for each search in seconds
+            include_rag: Whether to include RAG tool in searches
+            auto_ingest_to_rag: Whether to automatically ingest results into RAG
         """
-        self.tools = tools
+        self.tools = list(tools)  # Make a copy
         self.timeout = timeout
+        self.auto_ingest_to_rag = auto_ingest_to_rag
+        self._rag_service: "LlamaIndexRAGService | None" = None
+
+        if include_rag:
+            self.add_rag_tool()
+
+    def add_rag_tool(self) -> None:
+        """Add RAG tool to the tools list if available."""
+        try:
+            rag_tool = create_rag_tool()
+            self.tools.append(rag_tool)
+            logger.info("RAG tool added to search handler")
+        except ConfigurationError:
+            logger.warning(
+                "RAG tool unavailable, not adding to search handler",
+                hint="OPENAI_API_KEY required",
+            )
+        except Exception as e:
+            logger.error("Failed to add RAG tool", error=str(e))
+
+    def _get_rag_service(self) -> "LlamaIndexRAGService | None":
+        """Get or create RAG service for ingestion."""
+        if self._rag_service is None and self.auto_ingest_to_rag:
+            try:
+                from src.services.llamaindex_rag import get_rag_service
+
+                self._rag_service = get_rag_service()
+                logger.info("RAG service initialized for ingestion")
+            except (ConfigurationError, ImportError):
+                logger.warning("RAG service unavailable for ingestion")
+                return None
+        return self._rag_service
 
     async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
         """
@@ -66,7 +110,7 @@
             sources_searched.append(tool_name)
             logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
 
-        return SearchResult(
+        search_result = SearchResult(
             query=query,
             evidence=all_evidence,
             sources_searched=sources_searched,
@@ -74,6 +118,24 @@
             errors=errors,
         )
 
+        # Ingest evidence into RAG if enabled and available
+        if self.auto_ingest_to_rag and all_evidence:
+            rag_service = self._get_rag_service()
+            if rag_service:
+                try:
+                    # Filter out RAG-sourced evidence (avoid circular ingestion)
+                    evidence_to_ingest = [e for e in all_evidence if e.citation.source != "rag"]
+                    if evidence_to_ingest:
+                        rag_service.ingest_evidence(evidence_to_ingest)
+                        logger.info(
+                            "Ingested evidence into RAG",
+                            count=len(evidence_to_ingest),
+                        )
+                except Exception as e:
+                    logger.warning("Failed to ingest evidence into RAG", error=str(e))
+
+        return search_result
+
     async def _search_with_timeout(
         self,
         tool: SearchTool,
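A sketch of the extended handler wiring, assuming PubMedTool's default constructor: with include_rag=True the RAG index is queried alongside the live sources, and with auto_ingest_to_rag=True each execute() call feeds the fresh (non-RAG) evidence back into the index:

import asyncio

from src.tools import PubMedTool, SearchHandler


async def demo() -> None:
    handler = SearchHandler(
        tools=[PubMedTool()],
        timeout=30.0,
        include_rag=True,
        auto_ingest_to_rag=True,
    )
    result = await handler.execute("drug repurposing long COVID", max_results_per_tool=5)
    print(result.sources_searched, len(result.evidence))


asyncio.run(demo())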
src/tools/tool_executor.py
ADDED
|
@@ -0,0 +1,193 @@
+"""Tool executor for running AgentTask objects.
+
+Executes tool tasks selected by the tool selector agent and returns ToolAgentOutput.
+"""
+
+import structlog
+
+from src.tools.crawl_adapter import crawl_website
+from src.tools.rag_tool import RAGTool, create_rag_tool
+from src.tools.web_search_adapter import web_search
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import AgentTask, Evidence, ToolAgentOutput
+
+logger = structlog.get_logger()
+
+# Module-level RAG tool instance (lazy initialization)
+_rag_tool: RAGTool | None = None
+
+
+def _get_rag_tool() -> RAGTool | None:
+    """
+    Get or create RAG tool instance.
+
+    Returns:
+        RAGTool instance, or None if unavailable
+    """
+    global _rag_tool
+    if _rag_tool is None:
+        try:
+            _rag_tool = create_rag_tool()
+            logger.info("RAG tool initialized")
+        except ConfigurationError:
+            logger.warning("RAG tool unavailable (OPENAI_API_KEY required)")
+            return None
+        except Exception as e:
+            logger.error("Failed to initialize RAG tool", error=str(e))
+            return None
+    return _rag_tool
+
+
+def _evidence_to_text(evidence_list: list[Evidence]) -> str:
+    """
+    Convert Evidence objects to formatted text.
+
+    Args:
+        evidence_list: List of Evidence objects
+
+    Returns:
+        Formatted text string with citations and content
+    """
+    if not evidence_list:
+        return "No evidence found."
+
+    formatted_parts = []
+    for i, evidence in enumerate(evidence_list, 1):
+        citation = evidence.citation
+        citation_str = f"{citation.formatted}"
+        if citation.url:
+            citation_str += f" [{citation.url}]"
+
+        formatted_parts.append(f"[{i}] {citation_str}\n\n{evidence.content}\n\n---\n")
+
+    return "\n".join(formatted_parts)
+
+
+async def execute_agent_task(task: AgentTask) -> ToolAgentOutput:
+    """
+    Execute a single agent task and return ToolAgentOutput.
+
+    Args:
+        task: AgentTask specifying which tool to use and what query to run
+
+    Returns:
+        ToolAgentOutput with results and source URLs
+    """
+    logger.info(
+        "Executing agent task",
+        agent=task.agent,
+        query=task.query[:100] if task.query else "",
+        gap=task.gap[:100] if task.gap else "",
+    )
+
+    try:
+        if task.agent == "WebSearchAgent":
+            # Use web search adapter
+            result_text = await web_search(task.query)
+            # Extract URLs from result (simple heuristic - look for http/https)
+            import re
+
+            urls = re.findall(r"https?://[^\s\)]+", result_text)
+            sources = list(set(urls))  # Deduplicate
+
+            return ToolAgentOutput(output=result_text, sources=sources)
+
+        elif task.agent == "SiteCrawlerAgent":
+            # Use crawl adapter
+            if task.entity_website:
+                starting_url = task.entity_website
+            elif task.query.startswith(("http://", "https://")):
+                starting_url = task.query
+            else:
+                # Try to construct URL from query
+                starting_url = f"https://{task.query}"
+
+            result_text = await crawl_website(starting_url)
+            # Extract URLs from result
+            import re
+
+            urls = re.findall(r"https?://[^\s\)]+", result_text)
+            sources = list(set(urls))  # Deduplicate
+
+            return ToolAgentOutput(output=result_text, sources=sources)
+
+        elif task.agent == "RAGAgent":
+            # Use RAG tool for semantic search
+            rag_tool = _get_rag_tool()
+            if rag_tool is None:
+                return ToolAgentOutput(
+                    output="RAG service unavailable. OPENAI_API_KEY required.",
+                    sources=[],
+                )
+
+            # Search RAG and get Evidence objects
+            evidence_list = await rag_tool.search(task.query, max_results=10)
+
+            if not evidence_list:
+                return ToolAgentOutput(
+                    output="No relevant evidence found in collected research.",
+                    sources=[],
+                )
+
+            # Convert Evidence to formatted text
+            result_text = _evidence_to_text(evidence_list)
+
+            # Extract URLs from evidence citations
+            sources = [evidence.citation.url for evidence in evidence_list if evidence.citation.url]
+
+            return ToolAgentOutput(output=result_text, sources=sources)
+
+        else:
+            logger.warning("Unknown agent type", agent=task.agent)
+            return ToolAgentOutput(
+                output=f"Unknown agent type: {task.agent}. Available: WebSearchAgent, SiteCrawlerAgent, RAGAgent",
+                sources=[],
+            )
+
+    except Exception as e:
+        logger.error("Tool execution failed", error=str(e), agent=task.agent)
+        return ToolAgentOutput(
+            output=f"Error executing {task.agent} for gap '{task.gap}': {e!s}",
+            sources=[],
+        )
+
+
+async def execute_tool_tasks(
+    tasks: list[AgentTask],
+) -> dict[str, ToolAgentOutput]:
+    """
+    Execute multiple agent tasks concurrently.
+
+    Args:
+        tasks: List of AgentTask objects to execute
+
+    Returns:
+        Dictionary mapping task keys to ToolAgentOutput results
+    """
+    import asyncio
+
+    logger.info("Executing tool tasks", count=len(tasks))
+
+    # Create async tasks
+    async_tasks = [execute_agent_task(task) for task in tasks]
+
+    # Run concurrently
+    results_list = await asyncio.gather(*async_tasks, return_exceptions=True)
+
+    # Build results dictionary
+    results: dict[str, ToolAgentOutput] = {}
+    for i, (task, result) in enumerate(zip(tasks, results_list, strict=False)):
+        if isinstance(result, Exception):
+            logger.error("Task execution failed", error=str(result), task_index=i)
+            results[f"{task.agent}_{i}"] = ToolAgentOutput(output=f"Error: {result!s}", sources=[])
+        else:
+            # Type narrowing: result is ToolAgentOutput after Exception check
+            assert isinstance(
+                result, ToolAgentOutput
+            ), "Expected ToolAgentOutput after Exception check"
+            key = f"{task.agent}_{task.gap or i}" if task.gap else f"{task.agent}_{i}"
+            results[key] = result
+
+    logger.info("Tool tasks completed", completed=len(results))
+
+    return results
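For orientation, a sketch of how the executor is driven — not from the commit; the gap and query strings are illustrative, but `AgentTask` fields and the result-key scheme match the code above:

```python
import asyncio

from src.tools.tool_executor import execute_tool_tasks
from src.utils.models import AgentTask

async def demo() -> None:
    tasks = [
        AgentTask(gap="mechanism", agent="WebSearchAgent", query="long COVID fatigue mechanism"),
        AgentTask(gap="trials", agent="RAGAgent", query="repurposed drug fatigue trials"),
    ]
    # Tasks run concurrently; exceptions become error-valued ToolAgentOutput entries.
    results = await execute_tool_tasks(tasks)
    for key, output in results.items():  # keys: "<agent>_<gap>" or "<agent>_<index>"
        print(key, len(output.sources), "sources")

asyncio.run(demo())
```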
src/tools/web_search_adapter.py
ADDED
@@ -0,0 +1,63 @@
+"""Web search tool adapter for Pydantic AI agents.
+
+Adapts the folder/tools/web_search.py implementation to work with Pydantic AI.
+"""
+
+import structlog
+
+logger = structlog.get_logger()
+
+
+async def web_search(query: str) -> str:
+    """
+    Perform a web search for a given query and return formatted results.
+
+    Use this tool to search the web for information relevant to the query.
+    Provide a query with 3-6 words as input.
+
+    Args:
+        query: The search query (3-6 words recommended)
+
+    Returns:
+        Formatted string with search results including titles, descriptions, and URLs
+    """
+    try:
+        # Lazy import to avoid requiring folder/ dependencies at import time
+        # This will use the existing web_search tool from folder/tools
+        from folder.llm_config import create_default_config
+        from folder.tools.web_search import create_web_search_tool
+
+        config = create_default_config()
+        web_search_tool = create_web_search_tool(config)
+
+        # Call the tool function
+        # The tool returns List[ScrapeResult] or str
+        results = await web_search_tool(query)
+
+        if isinstance(results, str):
+            # Error message returned
+            logger.warning("Web search returned error", error=results)
+            return results
+
+        if not results:
+            return f"No web search results found for: {query}"
+
+        # Format results for agent consumption
+        formatted = [f"Found {len(results)} web search results:\n"]
+        for i, result in enumerate(results[:5], 1):  # Limit to 5 results
+            formatted.append(f"{i}. **{result.title}**")
+            if result.description:
+                formatted.append(f"   {result.description[:200]}...")
+            formatted.append(f"   URL: {result.url}")
+            if result.text:
+                formatted.append(f"   Content: {result.text[:300]}...")
+            formatted.append("")
+
+        return "\n".join(formatted)
+
+    except ImportError as e:
+        logger.error("Web search tool not available", error=str(e))
+        return f"Web search tool not available: {e!s}"
+    except Exception as e:
+        logger.error("Web search failed", error=str(e), query=query)
+        return f"Error performing web search: {e!s}"
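A quick smoke test of the adapter in isolation — assuming the folder/ package and its search backend are importable; by design the function returns its error string rather than raising:

```python
import asyncio

from src.tools.web_search_adapter import web_search

# Prints formatted results, or an explanatory error string on failure.
print(asyncio.run(web_search("drug repurposing long COVID")))
```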
src/utils/citation_validator.py
CHANGED
@@ -85,3 +85,94 @@ def build_reference_from_evidence(evidence: "Evidence") -> dict[str, str]:
        "date": evidence.citation.date or "n.d.",
        "url": evidence.citation.url,
    }
+
+
+def validate_markdown_citations(
+    markdown_report: str, evidence: list["Evidence"]
+) -> tuple[str, int]:
+    """Validate citations in a markdown report against collected evidence.
+
+    This function validates citations in markdown format (e.g., [1], [2]) by:
+    1. Extracting URLs from the references section
+    2. Matching them against Evidence objects
+    3. Removing invalid citations from the report
+
+    Note:
+        This is a basic validation. For full validation, use ResearchReport
+        objects with validate_references().
+
+    Args:
+        markdown_report: The markdown report string with citations
+        evidence: List of Evidence objects collected during research
+
+    Returns:
+        Tuple of (validated_markdown, removed_count)
+    """
+    import re
+
+    # Build set of valid URLs from evidence
+    valid_urls = {e.citation.url for e in evidence}
+    valid_urls_lower = {url.lower() for url in valid_urls}
+
+    # Extract references section (everything after "## References" or "References:")
+    ref_section_pattern = r"(?i)(?:##\s*)?References:?\s*\n(.*?)(?=\n##|\Z)"
+    ref_match = re.search(ref_section_pattern, markdown_report, re.DOTALL)
+
+    if not ref_match:
+        # No references section found, return as-is
+        return markdown_report, 0
+
+    ref_section = ref_match.group(1)
+    ref_lines = ref_section.strip().split("\n")
+
+    # Parse references: [1] https://example.com or [1] https://example.com Title
+    valid_refs = []
+    removed_count = 0
+
+    for ref_line in ref_lines:
+        stripped_line = ref_line.strip()
+        if not stripped_line:
+            continue
+
+        # Extract URL from reference line
+        # Pattern: [N] URL or [N] URL Title
+        url_match = re.search(r"https?://[^\s\)]+", stripped_line)
+        if url_match:
+            url = url_match.group(0).rstrip(".,;")
+            url_lower = url.lower()
+
+            # Check if URL is valid
+            if url in valid_urls or url_lower in valid_urls_lower:
+                valid_refs.append(stripped_line)
+            else:
+                removed_count += 1
+                logger.warning(
+                    f"Removed invalid citation from markdown: {url[:80]}"
+                    + ("..." if len(url) > 80 else "")
+                )
+        else:
+            # No URL found, keep the line (might be formatted differently)
+            valid_refs.append(stripped_line)
+
+    # Rebuild references section
+    if valid_refs:
+        new_ref_section = "\n".join(valid_refs)
+        # Replace the old references section
+        validated_markdown = (
+            markdown_report[: ref_match.start(1)]
+            + new_ref_section
+            + markdown_report[ref_match.end(1) :]
+        )
+    else:
+        # No valid references, remove the entire section
+        validated_markdown = (
+            markdown_report[: ref_match.start()] + markdown_report[ref_match.end() :]
+        )
+
+    if removed_count > 0:
+        logger.info(
+            f"Citation validation removed {removed_count} invalid citations from markdown report. "
+            f"{len(valid_refs)} valid citations remain."
+        )
+
+    return validated_markdown, removed_count
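A sketch of the validator on a toy report — the Evidence/Citation field names match their use in the new tests below, and the URLs are made up:

```python
from src.utils.citation_validator import validate_markdown_citations
from src.utils.models import Citation, Evidence

evidence = [
    Evidence(
        content="...",
        citation=Citation(source="pubmed", title="Known paper",
                          url="https://example.com/known", date="2024-01-01"),
    )
]
report = (
    "Body text [1] [2]\n\n## References\n"
    "[1] https://example.com/known\n"
    "[2] https://example.com/unknown\n"
)
validated, removed = validate_markdown_citations(report, evidence)
assert removed == 1  # the unbacked [2] reference line is dropped
```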
src/utils/config.py
CHANGED
@@ -41,15 +41,65 @@ class Settings(BaseSettings):
        default="all-MiniLM-L6-v2",
        description="Local sentence-transformers model (used by EmbeddingService)",
    )
+    embedding_provider: Literal["openai", "local", "huggingface"] = Field(
+        default="local",
+        description="Embedding provider to use",
+    )
+    huggingface_embedding_model: str = Field(
+        default="sentence-transformers/all-MiniLM-L6-v2",
+        description="HuggingFace embedding model ID",
+    )
+
+    # HuggingFace Configuration
+    huggingface_api_key: str | None = Field(
+        default=None, description="HuggingFace API token (HF_TOKEN or HUGGINGFACE_API_KEY)"
+    )
+    huggingface_model: str = Field(
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        description="Default HuggingFace model ID for inference",
+    )

    # PubMed Configuration
    ncbi_api_key: str | None = Field(
        default=None, description="NCBI API key for higher rate limits"
    )

+    # Web Search Configuration
+    web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo"] = Field(
+        default="duckduckgo",
+        description="Web search provider to use",
+    )
+    serper_api_key: str | None = Field(default=None, description="Serper API key for Google search")
+    searchxng_host: str | None = Field(default=None, description="SearchXNG host URL")
+    brave_api_key: str | None = Field(default=None, description="Brave Search API key")
+    tavily_api_key: str | None = Field(default=None, description="Tavily API key")
+
    # Agent Configuration
    max_iterations: int = Field(default=10, ge=1, le=50)
    search_timeout: int = Field(default=30, description="Seconds to wait for search")
+    use_graph_execution: bool = Field(
+        default=False, description="Use graph-based execution for research flows"
+    )
+
+    # Budget & Rate Limiting Configuration
+    default_token_limit: int = Field(
+        default=100000,
+        ge=1000,
+        le=1000000,
+        description="Default token budget per research loop",
+    )
+    default_time_limit_minutes: int = Field(
+        default=10,
+        ge=1,
+        le=120,
+        description="Default time limit per research loop (minutes)",
+    )
+    default_iterations_limit: int = Field(
+        default=10,
+        ge=1,
+        le=50,
+        description="Default iterations limit per research loop",
+    )

    # Logging
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
@@ -58,6 +108,34 @@ class Settings(BaseSettings):
    modal_token_id: str | None = Field(default=None, description="Modal token ID")
    modal_token_secret: str | None = Field(default=None, description="Modal token secret")
    chroma_db_path: str = Field(default="./chroma_db", description="ChromaDB storage path")
+    chroma_db_persist: bool = Field(
+        default=True,
+        description="Whether to persist ChromaDB to disk",
+    )
+    chroma_db_host: str | None = Field(
+        default=None,
+        description="ChromaDB server host (for remote ChromaDB)",
+    )
+    chroma_db_port: int | None = Field(
+        default=None,
+        description="ChromaDB server port (for remote ChromaDB)",
+    )
+
+    # RAG Service Configuration
+    rag_collection_name: str = Field(
+        default="deepcritical_evidence",
+        description="ChromaDB collection name for RAG",
+    )
+    rag_similarity_top_k: int = Field(
+        default=5,
+        ge=1,
+        le=50,
+        description="Number of top results to retrieve from RAG",
+    )
+    rag_auto_ingest: bool = Field(
+        default=True,
+        description="Automatically ingest evidence into RAG",
+    )

    @property
    def modal_available(self) -> bool:
@@ -102,6 +180,26 @@ class Settings(BaseSettings):
        """Check if any LLM API key is available."""
        return self.has_openai_key or self.has_anthropic_key

+    @property
+    def has_huggingface_key(self) -> bool:
+        """Check if HuggingFace API key is available."""
+        return bool(self.huggingface_api_key)
+
+    @property
+    def web_search_available(self) -> bool:
+        """Check if web search is available (either no-key provider or API key present)."""
+        if self.web_search_provider == "duckduckgo":
+            return True  # No API key required
+        if self.web_search_provider == "serper":
+            return bool(self.serper_api_key)
+        if self.web_search_provider == "searchxng":
+            return bool(self.searchxng_host)
+        if self.web_search_provider == "brave":
+            return bool(self.brave_api_key)
+        if self.web_search_provider == "tavily":
+            return bool(self.tavily_api_key)
+        return False
+

def get_settings() -> Settings:
    """Factory function to get settings (allows mocking in tests)."""
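A sketch of how the new settings are read — values come from the environment via pydantic-settings, and the field names are exactly those added above:

```python
from src.utils.config import get_settings

settings = get_settings()
# web_search_available gates on the selected provider; duckduckgo needs no key.
if settings.web_search_available:
    print("provider:", settings.web_search_provider)
print(settings.default_token_limit, settings.rag_auto_ingest)
```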
src/utils/models.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Any, ClassVar, Literal
from pydantic import BaseModel, Field

# Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
-SourceName = Literal["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint"]
+SourceName = Literal["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web"]


class Citation(BaseModel):
@@ -303,3 +303,269 @@ class OrchestratorConfig(BaseModel):
    max_iterations: int = Field(default=10, ge=1, le=20)
    max_results_per_tool: int = Field(default=10, ge=1, le=50)
    search_timeout: float = Field(default=30.0, ge=5.0, le=120.0)
+
+
+# Models for iterative/deep research patterns
+
+
+class IterationData(BaseModel):
+    """Data for a single iteration of the research loop."""
+
+    gap: str = Field(description="The gap addressed in the iteration", default="")
+    tool_calls: list[str] = Field(description="The tool calls made", default_factory=list)
+    findings: list[str] = Field(
+        description="The findings collected from tool calls", default_factory=list
+    )
+    thought: str = Field(
+        description="The thinking done to reflect on the success of the iteration and next steps",
+        default="",
+    )
+
+    model_config = {"frozen": True}
+
+
+class Conversation(BaseModel):
+    """A conversation between the user and the iterative researcher."""
+
+    history: list[IterationData] = Field(
+        description="The data for each iteration of the research loop",
+        default_factory=list,
+    )
+
+    def add_iteration(self, iteration_data: IterationData | None = None) -> None:
+        """Add a new iteration to the conversation history."""
+        if iteration_data is None:
+            iteration_data = IterationData()
+        self.history.append(iteration_data)
+
+    def set_latest_gap(self, gap: str) -> None:
+        """Set the gap for the latest iteration."""
+        if not self.history:
+            self.add_iteration()
+        # Use model_copy() since IterationData is frozen
+        self.history[-1] = self.history[-1].model_copy(update={"gap": gap})
+
+    def set_latest_tool_calls(self, tool_calls: list[str]) -> None:
+        """Set the tool calls for the latest iteration."""
+        if not self.history:
+            self.add_iteration()
+        # Use model_copy() since IterationData is frozen
+        self.history[-1] = self.history[-1].model_copy(update={"tool_calls": tool_calls})
+
+    def set_latest_findings(self, findings: list[str]) -> None:
+        """Set the findings for the latest iteration."""
+        if not self.history:
+            self.add_iteration()
+        # Use model_copy() since IterationData is frozen
+        self.history[-1] = self.history[-1].model_copy(update={"findings": findings})
+
+    def set_latest_thought(self, thought: str) -> None:
+        """Set the thought for the latest iteration."""
+        if not self.history:
+            self.add_iteration()
+        # Use model_copy() since IterationData is frozen
+        self.history[-1] = self.history[-1].model_copy(update={"thought": thought})
+
+    def get_latest_gap(self) -> str:
+        """Get the gap from the latest iteration."""
+        if not self.history:
+            return ""
+        return self.history[-1].gap
+
+    def get_latest_tool_calls(self) -> list[str]:
+        """Get the tool calls from the latest iteration."""
+        if not self.history:
+            return []
+        return self.history[-1].tool_calls
+
+    def get_latest_findings(self) -> list[str]:
+        """Get the findings from the latest iteration."""
+        if not self.history:
+            return []
+        return self.history[-1].findings
+
+    def get_latest_thought(self) -> str:
+        """Get the thought from the latest iteration."""
+        if not self.history:
+            return ""
+        return self.history[-1].thought
+
+    def get_all_findings(self) -> list[str]:
+        """Get all findings from all iterations."""
+        return [finding for iteration_data in self.history for finding in iteration_data.findings]
+
+    def compile_conversation_history(self) -> str:
+        """Compile the conversation history into a string."""
+        conversation = ""
+        for iteration_num, iteration_data in enumerate(self.history):
+            conversation += f"[ITERATION {iteration_num + 1}]\n\n"
+            if iteration_data.thought:
+                conversation += f"{self.get_thought_string(iteration_num)}\n\n"
+            if iteration_data.gap:
+                conversation += f"{self.get_task_string(iteration_num)}\n\n"
+            if iteration_data.tool_calls:
+                conversation += f"{self.get_action_string(iteration_num)}\n\n"
+            if iteration_data.findings:
+                conversation += f"{self.get_findings_string(iteration_num)}\n\n"
+
+        return conversation
+
+    def get_task_string(self, iteration_num: int) -> str:
+        """Get the task for the specified iteration."""
+        if iteration_num < len(self.history) and self.history[iteration_num].gap:
+            return (
+                f"<task>\nAddress this knowledge gap: "
+                f"{self.history[iteration_num].gap}\n</task>"
+            )
+        return ""
+
+    def get_action_string(self, iteration_num: int) -> str:
+        """Get the action for the specified iteration."""
+        if iteration_num < len(self.history) and self.history[iteration_num].tool_calls:
+            joined_calls = "\n".join(self.history[iteration_num].tool_calls)
+            return (
+                "<action>\nCalling the following tools to address the knowledge gap:\n"
+                f"{joined_calls}\n</action>"
+            )
+        return ""
+
+    def get_findings_string(self, iteration_num: int) -> str:
+        """Get the findings for the specified iteration."""
+        if iteration_num < len(self.history) and self.history[iteration_num].findings:
+            joined_findings = "\n\n".join(self.history[iteration_num].findings)
+            return f"<findings>\n{joined_findings}\n</findings>"
+        return ""
+
+    def get_thought_string(self, iteration_num: int) -> str:
+        """Get the thought for the specified iteration."""
+        if iteration_num < len(self.history) and self.history[iteration_num].thought:
+            return f"<thought>\n{self.history[iteration_num].thought}\n</thought>"
+        return ""
+
+    def latest_task_string(self) -> str:
+        """Get the latest task."""
+        if not self.history:
+            return ""
+        return self.get_task_string(len(self.history) - 1)
+
+    def latest_action_string(self) -> str:
+        """Get the latest action."""
+        if not self.history:
+            return ""
+        return self.get_action_string(len(self.history) - 1)
+
+    def latest_findings_string(self) -> str:
+        """Get the latest findings."""
+        if not self.history:
+            return ""
+        return self.get_findings_string(len(self.history) - 1)
+
+    def latest_thought_string(self) -> str:
+        """Get the latest thought."""
+        if not self.history:
+            return ""
+        return self.get_thought_string(len(self.history) - 1)
+
+
+class ReportPlanSection(BaseModel):
+    """A section of the report that needs to be written."""
+
+    title: str = Field(description="The title of the section")
+    key_question: str = Field(description="The key question to be addressed in the section")
+
+    model_config = {"frozen": True}
+
+
+class ReportPlan(BaseModel):
+    """Output from the Report Planner Agent."""
+
+    background_context: str = Field(
+        description="A summary of supporting context that can be passed onto the research agents"
+    )
+    report_outline: list[ReportPlanSection] = Field(
+        description="List of sections that need to be written in the report"
+    )
+    report_title: str = Field(description="The title of the report")
+
+    model_config = {"frozen": True}
+
+
+class KnowledgeGapOutput(BaseModel):
+    """Output from the Knowledge Gap Agent."""
+
+    research_complete: bool = Field(
+        description="Whether the research and findings are complete enough to end the research loop"
+    )
+    outstanding_gaps: list[str] = Field(
+        description="List of knowledge gaps that still need to be addressed"
+    )
+
+    model_config = {"frozen": True}
+
+
+class AgentTask(BaseModel):
+    """A task for a specific agent to address knowledge gaps."""
+
+    gap: str | None = Field(description="The knowledge gap being addressed", default=None)
+    agent: str = Field(description="The name of the agent to use")
+    query: str = Field(description="The specific query for the agent")
+    entity_website: str | None = Field(
+        description="The website of the entity being researched, if known",
+        default=None,
+    )
+
+    model_config = {"frozen": True}
+
+
+class AgentSelectionPlan(BaseModel):
+    """Plan for which agents to use for knowledge gaps."""
+
+    tasks: list[AgentTask] = Field(description="List of agent tasks to address knowledge gaps")
+
+    model_config = {"frozen": True}
+
+
+class ReportDraftSection(BaseModel):
+    """A section of the report that needs to be written."""
+
+    section_title: str = Field(description="The title of the section")
+    section_content: str = Field(description="The content of the section")
+
+    model_config = {"frozen": True}
+
+
+class ReportDraft(BaseModel):
+    """Output from the Report Planner Agent."""
+
+    sections: list[ReportDraftSection] = Field(
+        description="List of sections that are in the report"
+    )
+
+    model_config = {"frozen": True}
+
+
+class ToolAgentOutput(BaseModel):
+    """Standard output for all tool agents."""
+
+    output: str = Field(description="The output from the tool agent")
+    sources: list[str] = Field(description="List of source URLs", default_factory=list)
+
+    model_config = {"frozen": True}
+
+
+class ParsedQuery(BaseModel):
+    """Parsed and improved user query with research mode detection."""
+
+    original_query: str = Field(description="The original user query")
+    improved_query: str = Field(description="Improved/refined query")
+    research_mode: Literal["iterative", "deep"] = Field(description="Detected research mode")
+    key_entities: list[str] = Field(
+        default_factory=list,
+        description="Key entities extracted from query",
+    )
+    research_questions: list[str] = Field(
+        default_factory=list,
+        description="Specific research questions extracted",
+    )
+
+    model_config = {"frozen": True}
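Because IterationData is frozen, Conversation's setters replace the last history element via model_copy(update=...) rather than mutating it in place. A short sketch of that round trip (values are illustrative):

```python
from src.utils.models import Conversation, IterationData

conv = Conversation()
conv.add_iteration(IterationData(gap="initial gap"))
conv.set_latest_findings(["finding A", "finding B"])  # swaps in a frozen copy
conv.set_latest_thought("coverage still thin on trials")

# Emits [ITERATION 1] with <thought>/<task>/<findings> blocks.
print(conv.compile_conversation_history())
```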
tests/integration/test_deep_research.py
ADDED
@@ -0,0 +1,352 @@
+"""Integration tests for deep research flow.
+
+Tests the complete deep research pattern: plan → parallel loops → synthesis.
+"""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from src.middleware.state_machine import init_workflow_state
+from src.orchestrator.research_flow import DeepResearchFlow
+from src.utils.models import ReportPlan, ReportPlanSection
+
+
+@pytest.mark.integration
+class TestDeepResearchFlow:
+    """Integration tests for DeepResearchFlow."""
+
+    @pytest.mark.asyncio
+    async def test_deep_research_creates_plan(self) -> None:
+        """Test that deep research creates a report plan."""
+        # Initialize workflow state
+        init_workflow_state()
+
+        flow = DeepResearchFlow(
+            max_iterations=2,
+            max_time_minutes=5,
+            verbose=False,
+            use_graph=False,
+        )
+
+        # Mock the planner agent to return a simple plan
+        mock_plan = ReportPlan(
+            background_context="Test background context",
+            report_outline=[
+                ReportPlanSection(
+                    title="Section 1",
+                    key_question="What is the first question?",
+                ),
+                ReportPlanSection(
+                    title="Section 2",
+                    key_question="What is the second question?",
+                ),
+            ],
+            report_title="Test Report",
+        )
+
+        flow.planner_agent.run = AsyncMock(return_value=mock_plan)
+
+        # Mock the iterative research flows to return simple drafts
+        async def mock_iterative_run(query: str, **kwargs: dict) -> str:
+            return f"# Draft for: {query}\n\nThis is a test draft."
+
+        # Mock the long writer to return a simple report
+        flow.long_writer_agent.write_report = AsyncMock(
+            return_value="# Test Report\n\n## Section 1\n\nDraft 1\n\n## Section 2\n\nDraft 2"
+        )
+
+        # We can't easily mock the IterativeResearchFlow.run() without more setup
+        # So we'll test the plan creation separately
+        plan = await flow._build_report_plan("Test query")
+
+        assert isinstance(plan, ReportPlan)
+        assert plan.report_title == "Test Report"
+        assert len(plan.report_outline) == 2
+        assert plan.report_outline[0].title == "Section 1"
+
+    @pytest.mark.asyncio
+    async def test_deep_research_parallel_loops_state_synchronization(self) -> None:
+        """Test that parallel loops properly synchronize state."""
+        # Initialize workflow state
+        state = init_workflow_state()
+
+        flow = DeepResearchFlow(
+            max_iterations=1,
+            max_time_minutes=2,
+            verbose=False,
+            use_graph=False,
+        )
+
+        # Create a simple report plan
+        report_plan = ReportPlan(
+            background_context="Test background",
+            report_outline=[
+                ReportPlanSection(
+                    title="Section 1",
+                    key_question="Question 1?",
+                ),
+                ReportPlanSection(
+                    title="Section 2",
+                    key_question="Question 2?",
+                ),
+            ],
+            report_title="Test Report",
+        )
+
+        # Mock iterative research flows to add evidence to state
+        from src.utils.models import Citation, Evidence
+
+        async def mock_iterative_run(query: str, **kwargs: dict) -> str:
+            # Add evidence to state to test synchronization
+            ev = Evidence(
+                content=f"Evidence for {query}",
+                citation=Citation(
+                    source="pubmed",
+                    title=f"Title for {query}",
+                    url=f"https://example.com/{query.replace('?', '').replace(' ', '_')}",
+                    date="2024-01-01",
+                ),
+            )
+            state.add_evidence([ev])
+            return f"# Draft: {query}\n\nTest content."
+
+        # Patch IterativeResearchFlow.run
+        with patch(
+            "src.orchestrator.research_flow.IterativeResearchFlow.run",
+            side_effect=mock_iterative_run,
+        ):
+            section_drafts = await flow._run_research_loops(report_plan)
+
+        # Verify parallel execution
+        assert len(section_drafts) == 2
+        assert "Question 1" in section_drafts[0]
+        assert "Question 2" in section_drafts[1]
+
+        # Verify state has evidence from both sections
+        # Note: In real execution, evidence would be synced via WorkflowManager
+        # This test verifies the structure works
+
+    @pytest.mark.asyncio
+    async def test_deep_research_synthesizes_final_report(self) -> None:
+        """Test that deep research synthesizes final report from section drafts."""
+        flow = DeepResearchFlow(
+            max_iterations=1,
+            max_time_minutes=2,
+            verbose=False,
+            use_graph=False,
+            use_long_writer=True,
+        )
+
+        # Create report plan
+        report_plan = ReportPlan(
+            background_context="Test background",
+            report_outline=[
+                ReportPlanSection(
+                    title="Introduction",
+                    key_question="What is the topic?",
+                ),
+                ReportPlanSection(
+                    title="Conclusion",
+                    key_question="What are the conclusions?",
+                ),
+            ],
+            report_title="Test Report",
+        )
+
+        # Create section drafts
+        section_drafts = [
+            "# Introduction\n\nThis is the introduction section.",
+            "# Conclusion\n\nThis is the conclusion section.",
+        ]
+
+        # Mock long writer
+        flow.long_writer_agent.write_report = AsyncMock(
+            return_value="# Test Report\n\n## Introduction\n\nContent\n\n## Conclusion\n\nContent"
+        )
+
+        final_report = await flow._create_final_report("Test query", report_plan, section_drafts)
+
+        assert isinstance(final_report, str)
+        assert "Test Report" in final_report
+        # Verify long writer was called with correct parameters
+        flow.long_writer_agent.write_report.assert_called_once()
+        call_args = flow.long_writer_agent.write_report.call_args
+        assert call_args.kwargs["original_query"] == "Test query"
+        assert call_args.kwargs["report_title"] == "Test Report"
+        assert len(call_args.kwargs["report_draft"].sections) == 2
+
+    @pytest.mark.asyncio
+    async def test_deep_research_agent_chains_full_flow(self) -> None:
+        """Test full deep research flow with agent chains (mocked)."""
+        # Initialize workflow state
+        init_workflow_state()
+
+        flow = DeepResearchFlow(
+            max_iterations=1,
+            max_time_minutes=2,
+            verbose=False,
+            use_graph=False,
+        )
+
+        # Mock all agents
+        mock_plan = ReportPlan(
+            background_context="Background",
+            report_outline=[
+                ReportPlanSection(
+                    title="Section 1",
+                    key_question="Question 1?",
+                ),
+            ],
+            report_title="Test Report",
+        )
+
+        flow.planner_agent.run = AsyncMock(return_value=mock_plan)
+
+        # Mock iterative research
+        async def mock_iterative_run(query: str, **kwargs: dict) -> str:
+            return f"# Draft\n\nAnswer to {query}"
+
+        with patch(
+            "src.orchestrator.research_flow.IterativeResearchFlow.run",
+            side_effect=mock_iterative_run,
+        ):
+            flow.long_writer_agent.write_report = AsyncMock(
+                return_value="# Test Report\n\n## Section 1\n\nDraft content"
+            )
+
+            # Run the full flow
+            result = await flow._run_with_chains("Test query")
+
+        assert isinstance(result, str)
+        assert "Test Report" in result
+        flow.planner_agent.run.assert_called_once()
+        flow.long_writer_agent.write_report.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_deep_research_handles_multiple_sections(self) -> None:
+        """Test that deep research handles multiple sections correctly."""
+        flow = DeepResearchFlow(
+            max_iterations=1,
+            max_time_minutes=2,
+            verbose=False,
+            use_graph=False,
+        )
+
+        # Create plan with multiple sections
+        report_plan = ReportPlan(
+            background_context="Background",
+            report_outline=[
+                ReportPlanSection(
+                    title=f"Section {i}",
+                    key_question=f"Question {i}?",
+                )
+                for i in range(5)  # 5 sections
+            ],
+            report_title="Multi-Section Report",
+        )
+
+        # Mock iterative research to return unique drafts
+        async def mock_iterative_run(query: str, **kwargs: dict) -> str:
+            section_num = query.split()[-1].replace("?", "")
+            return f"# Section {section_num} Draft\n\nContent for section {section_num}"
+
+        with patch(
+            "src.orchestrator.research_flow.IterativeResearchFlow.run",
+            side_effect=mock_iterative_run,
+        ):
+            section_drafts = await flow._run_research_loops(report_plan)
+
+        # Verify all sections were processed
+        assert len(section_drafts) == 5
+        for i, draft in enumerate(section_drafts):
+            assert f"Section {i}" in draft or f"section {i}" in draft.lower()
+
+    @pytest.mark.asyncio
+    async def test_deep_research_workflow_manager_integration(self) -> None:
+        """Test that deep research properly uses WorkflowManager."""
+
+        # Initialize workflow state
+        init_workflow_state()
+
+        flow = DeepResearchFlow(
+            max_iterations=1,
+            max_time_minutes=2,
+            verbose=False,
+            use_graph=False,
+        )
+
+        # Create report plan
+        report_plan = ReportPlan(
+            background_context="Background",
+            report_outline=[
+                ReportPlanSection(
+                    title="Section 1",
+                    key_question="Question 1?",
+                ),
+                ReportPlanSection(
+                    title="Section 2",
+                    key_question="Question 2?",
+                ),
+            ],
+            report_title="Test Report",
+        )
+
+        # Mock iterative research
+        async def mock_iterative_run(query: str, **kwargs: dict) -> str:
+            return f"# Draft: {query}"
+
+        with patch(
+            "src.orchestrator.research_flow.IterativeResearchFlow.run",
+            side_effect=mock_iterative_run,
+        ):
+            section_drafts = await flow._run_research_loops(report_plan)
+
+        # Verify WorkflowManager was used (section_drafts should be returned)
+        assert len(section_drafts) == 2
+        # Each draft should be a string
+        assert all(isinstance(draft, str) for draft in section_drafts)
+
+    @pytest.mark.asyncio
+    async def test_deep_research_state_initialization(self) -> None:
+        """Test that deep research properly initializes workflow state."""
+        flow = DeepResearchFlow(
+            max_iterations=1,
+            max_time_minutes=2,
+            verbose=False,
+            use_graph=False,
+        )
+
+        # Mock the planner
+        mock_plan = ReportPlan(
+            background_context="Background",
+            report_outline=[
+                ReportPlanSection(
+                    title="Section 1",
+                    key_question="Question 1?",
+                ),
+            ],
+            report_title="Test Report",
+        )
+
+        flow.planner_agent.run = AsyncMock(return_value=mock_plan)
+
+        # Mock iterative research
+        async def mock_iterative_run(query: str, **kwargs: dict) -> str:
+            return "# Draft"
+
+        with patch(
+            "src.orchestrator.research_flow.IterativeResearchFlow.run",
+            side_effect=mock_iterative_run,
+        ):
+            flow.long_writer_agent.write_report = AsyncMock(return_value="# Test Report\n\nContent")
+
+            # Run with chains - should initialize state
+            # Note: _run_with_chains handles missing embedding service gracefully
+            await flow._run_with_chains("Test query")
+
+        # Verify state was initialized (get_workflow_state should not raise)
+        from src.middleware.state_machine import get_workflow_state
+
+        state = get_workflow_state()
+        assert state is not None
tests/integration/test_middleware_integration.py
ADDED
@@ -0,0 +1,245 @@
+"""Integration tests for middleware components.
+
+Tests the interaction between WorkflowState, WorkflowManager, and BudgetTracker.
+"""
+
+import pytest
+
+from src.middleware.budget_tracker import BudgetTracker
+from src.middleware.state_machine import init_workflow_state
+from src.middleware.workflow_manager import WorkflowManager
+from src.utils.models import Citation, Evidence
+
+
+@pytest.mark.integration
+class TestMiddlewareIntegration:
+    """Integration tests for middleware components."""
+
+    @pytest.mark.asyncio
+    async def test_state_manager_integration(self) -> None:
+        """Test WorkflowState and WorkflowManager integration."""
+        # Initialize state
+        state = init_workflow_state()
+        manager = WorkflowManager()
+
+        # Create a loop
+        loop = await manager.add_loop("test_loop", "Test query")
+
+        # Add evidence to loop
+        ev = Evidence(
+            content="Test evidence",
+            citation=Citation(
+                source="pubmed", title="Test Title", url="https://example.com/1", date="2024-01-01"
+            ),
+        )
+        await manager.add_loop_evidence("test_loop", [ev])
+
+        # Sync to global state
+        await manager.sync_loop_evidence_to_state("test_loop")
+
+        # Verify state has evidence
+        assert len(state.evidence) == 1
+        assert state.evidence[0].content == "Test evidence"
+
+        # Verify loop still has evidence
+        loop = await manager.get_loop("test_loop")
+        assert loop is not None
+        assert len(loop.evidence) == 1
+
+    @pytest.mark.asyncio
+    async def test_budget_tracker_with_workflow_manager(self) -> None:
+        """Test BudgetTracker integration with WorkflowManager."""
+        manager = WorkflowManager()
+        tracker = BudgetTracker()
+
+        # Create loop and budget
+        await manager.add_loop("budget_loop", "Test query")
+        tracker.create_budget("budget_loop", tokens_limit=1000, time_limit_seconds=60.0)
+        tracker.start_timer("budget_loop")
+
+        # Simulate some work
+        tracker.add_tokens("budget_loop", 500)
+        await manager.increment_loop_iteration("budget_loop")
+        tracker.increment_iteration("budget_loop")
+
+        # Check budget
+        can_continue = tracker.can_continue("budget_loop")
+        assert can_continue is True
+
+        # Exceed budget
+        tracker.add_tokens("budget_loop", 600)  # Total: 1100 > 1000
+        can_continue = tracker.can_continue("budget_loop")
+        assert can_continue is False
+
+        # Update loop status based on budget
+        if not can_continue:
+            await manager.update_loop_status("budget_loop", "cancelled")
+
+        loop = await manager.get_loop("budget_loop")
+        assert loop is not None
+        assert loop.status == "cancelled"
+
+    @pytest.mark.asyncio
+    async def test_parallel_loops_with_budget_tracking(self) -> None:
+        """Test parallel loops with budget tracking."""
+
+        async def mock_research_loop(config: dict) -> str:
+            """Mock research loop function."""
+            loop_id = config.get("loop_id", "unknown")
+            tracker = BudgetTracker()
+            manager = WorkflowManager()
+
+            # Get or create budget
+            budget = tracker.get_budget(loop_id)
+            if not budget:
+                tracker.create_budget(loop_id, tokens_limit=500, time_limit_seconds=10.0)
+                tracker.start_timer(loop_id)
+
+            # Simulate work
+            tracker.add_tokens(loop_id, 100)
+            await manager.increment_loop_iteration(loop_id)
+            tracker.increment_iteration(loop_id)
+
+            # Check if can continue
+            if not tracker.can_continue(loop_id):
+                await manager.update_loop_status(loop_id, "cancelled")
+                return f"Cancelled: {loop_id}"
+
+            await manager.update_loop_status(loop_id, "completed")
+            return f"Completed: {loop_id}"
+
+        manager = WorkflowManager()
+        tracker = BudgetTracker()
+
+        # Create budgets for all loops
+        configs = [
+            {"loop_id": "loop1", "query": "Query 1"},
+            {"loop_id": "loop2", "query": "Query 2"},
+            {"loop_id": "loop3", "query": "Query 3"},
+        ]
+
+        for config in configs:
+            loop_id = config["loop_id"]
+            await manager.add_loop(loop_id, config["query"])
+            tracker.create_budget(loop_id, tokens_limit=500, time_limit_seconds=10.0)
+            tracker.start_timer(loop_id)
+
+        # Run loops in parallel
+        results = await manager.run_loops_parallel(configs, mock_research_loop)
+
+        # Verify all loops completed
+        assert len(results) == 3
+        for config in configs:
+            loop_id = config["loop_id"]
+            loop = await manager.get_loop(loop_id)
+            assert loop is not None
+            assert loop.status in ("completed", "cancelled")
+
+    @pytest.mark.asyncio
+    async def test_state_conversation_integration(self) -> None:
+        """Test WorkflowState conversation integration."""
+        state = init_workflow_state()
+
+        # Add iteration data
+        state.conversation.add_iteration()
+        state.conversation.set_latest_gap("Knowledge gap 1")
+        state.conversation.set_latest_tool_calls(["tool1", "tool2"])
+        state.conversation.set_latest_findings(["finding1", "finding2"])
+        state.conversation.set_latest_thought("Thought about findings")
+
+        # Verify conversation history
+        assert len(state.conversation.history) == 1
+        assert state.conversation.get_latest_gap() == "Knowledge gap 1"
+        assert len(state.conversation.get_latest_tool_calls()) == 2
+        assert len(state.conversation.get_latest_findings()) == 2
+
+        # Compile history
+        history_str = state.conversation.compile_conversation_history()
+        assert "Knowledge gap 1" in history_str
+        assert "tool1" in history_str
+        assert "finding1" in history_str
+        assert "Thought about findings" in history_str
+
+    @pytest.mark.asyncio
+    async def test_multiple_iterations_with_budget(self) -> None:
+        """Test multiple iterations with budget enforcement."""
+        manager = WorkflowManager()
+        tracker = BudgetTracker()
+
+        loop_id = "iterative_loop"
+        await manager.add_loop(loop_id, "Iterative query")
+        tracker.create_budget(loop_id, tokens_limit=1000, iterations_limit=5)
+        tracker.start_timer(loop_id)
+
+        # Simulate multiple iterations
+        for _ in range(7):  # Try 7 iterations, but limit is 5
+            tracker.add_tokens(loop_id, 100)
+            await manager.increment_loop_iteration(loop_id)
+            tracker.increment_iteration(loop_id)
+
+            can_continue = tracker.can_continue(loop_id)
+            if not can_continue:
+                await manager.update_loop_status(loop_id, "cancelled")
+                break
+
+        loop = await manager.get_loop(loop_id)
+        assert loop is not None
+        # Should be cancelled after 5 iterations
+        assert loop.status == "cancelled"
+        assert loop.iteration_count == 5
+
+    @pytest.mark.asyncio
+    async def test_evidence_deduplication_across_loops(self) -> None:
+        """Test evidence deduplication when syncing from multiple loops."""
+        state = init_workflow_state()
+        manager = WorkflowManager()
+
+        # Create two loops with same evidence
+        ev1 = Evidence(
+            content="Same content",
+            citation=Citation(
+                source="pubmed", title="Title", url="https://example.com/1", date="2024"
+            ),
+        )
+        ev2 = Evidence(
+            content="Different content",
+            citation=Citation(
+                source="pubmed", title="Title 2", url="https://example.com/2", date="2024"
+            ),
+        )
+
+        # Add to loop1
+        await manager.add_loop("loop1", "Query 1")
+        await manager.add_loop_evidence("loop1", [ev1, ev2])
+        await manager.sync_loop_evidence_to_state("loop1")
+
+        # Add duplicate to loop2
+        await manager.add_loop("loop2", "Query 2")
+        ev1_duplicate = Evidence(
+            content="Same content (duplicate)",
+            citation=Citation(
+                source="pubmed", title="Title Duplicate", url="https://example.com/1", date="2024"
+            ),
+        )
+        await manager.add_loop_evidence("loop2", [ev1_duplicate])
+        await manager.sync_loop_evidence_to_state("loop2")
+
+        # State should have only 2 unique items (deduplicated by URL)
+        assert len(state.evidence) == 2
+
+    @pytest.mark.asyncio
+    async def test_global_budget_enforcement(self) -> None:
+        """Test global budget enforcement across all loops."""
+        tracker = BudgetTracker()
+        tracker.set_global_budget(tokens_limit=2000, time_limit_seconds=60.0)
+
+        # Simulate multiple loops consuming global budget
+        tracker.add_global_tokens(500)  # Loop 1
+        tracker.add_global_tokens(600)  # Loop 2
+        tracker.add_global_tokens(700)  # Loop 3
+        tracker.add_global_tokens(300)  # Loop 4 - would exceed
+
+        global_budget = tracker.get_global_budget()
+        assert global_budget is not None
+        assert global_budget.tokens_used == 2100
+        assert global_budget.is_exceeded() is True
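The budget-gating pattern these tests exercise generalizes beyond the test suite: create a per-loop budget, record spend on each iteration, and stop once `can_continue` flips. A minimal sketch of that shape, using only calls exercised in the diff above; the 300-token per-iteration cost, the 10-iteration cap, and the final status strings are illustrative assumptions, not project conventions:

```python
import asyncio

from src.middleware.budget_tracker import BudgetTracker
from src.middleware.workflow_manager import WorkflowManager


async def budget_gated_loop(loop_id: str, query: str) -> str:
    """Run research iterations until the work is done or the budget is spent."""
    manager = WorkflowManager()
    tracker = BudgetTracker()

    await manager.add_loop(loop_id, query)
    tracker.create_budget(loop_id, tokens_limit=1000, time_limit_seconds=60.0)
    tracker.start_timer(loop_id)

    for _ in range(10):  # hypothetical upper bound on iterations
        # One unit of research work would go here; 300 is a placeholder cost.
        tracker.add_tokens(loop_id, 300)
        await manager.increment_loop_iteration(loop_id)
        tracker.increment_iteration(loop_id)

        if not tracker.can_continue(loop_id):
            await manager.update_loop_status(loop_id, "cancelled")
            return f"Cancelled: {loop_id}"

    await manager.update_loop_status(loop_id, "completed")
    return f"Completed: {loop_id}"


if __name__ == "__main__":
    print(asyncio.run(budget_gated_loop("demo_loop", "example query")))
```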
tests/integration/test_parallel_loops_judge.py
ADDED
@@ -0,0 +1,396 @@
+"""Integration tests for Phase 7: Parallel loops with judge-based completion.
+
+These tests verify that WorkflowManager can coordinate parallel research loops
+and use the judge to determine when loops should complete.
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from src.middleware.workflow_manager import WorkflowManager
+from src.orchestrator.research_flow import IterativeResearchFlow
+from src.utils.models import Citation, Evidence, JudgeAssessment
+
+
+@pytest.fixture
+def mock_judge_handler():
+    """Create a mock judge handler."""
+    judge = MagicMock()
+    judge.assess = AsyncMock()
+    return judge
+
+
+@pytest.fixture
+def mock_iterative_flow():
+    """Create a mock iterative research flow."""
+    flow = MagicMock(spec=IterativeResearchFlow)
+    flow.run = AsyncMock(return_value="# Test Report\n\nContent here.")
+    return flow
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+class TestParallelLoopsWithJudge:
+    """Tests for parallel loops with judge-based completion."""
+
+    async def test_get_loop_evidence(self):
+        """get_loop_evidence should return evidence from a loop."""
+        manager = WorkflowManager()
+        await manager.add_loop("loop1", "Test query")
+
+        # Add evidence to the loop
+        evidence = [
+            Evidence(
+                content="Test evidence",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test",
+                    url="https://example.com",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.8,
+            )
+        ]
+        await manager.add_loop_evidence("loop1", evidence)
+
+        # Retrieve evidence
+        retrieved_evidence = await manager.get_loop_evidence("loop1")
+        assert len(retrieved_evidence) == 1
+        assert retrieved_evidence[0].content == "Test evidence"
+
+    async def test_get_loop_evidence_returns_empty_for_missing_loop(self):
+        """get_loop_evidence should return empty list for non-existent loop."""
+        manager = WorkflowManager()
+        evidence = await manager.get_loop_evidence("nonexistent")
+        assert evidence == []
+
+    async def test_check_loop_completion_with_sufficient_evidence(self, mock_judge_handler):
+        """check_loop_completion should return True when judge says sufficient."""
+        manager = WorkflowManager()
+        await manager.add_loop("loop1", "Test query")
+
+        # Add evidence
+        evidence = [
+            Evidence(
+                content="Comprehensive evidence",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test",
+                    url="https://example.com",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.9,
+            )
+        ]
+        await manager.add_loop_evidence("loop1", evidence)
+
+        # Mock judge to say sufficient
+        from src.utils.models import AssessmentDetails
+
+        mock_judge_handler.assess = AsyncMock(
+            return_value=JudgeAssessment(
+                details=AssessmentDetails(
+                    mechanism_score=5,
+                    mechanism_reasoning="Test mechanism reasoning that is long enough",
+                    clinical_evidence_score=5,
+                    clinical_reasoning="Test clinical reasoning that is long enough",
+                    drug_candidates=[],
+                    key_findings=[],
+                ),
+                sufficient=True,
+                confidence=0.95,
+                recommendation="synthesize",
+                reasoning="Evidence is sufficient to provide a comprehensive answer.",
+            )
+        )
+
+        should_complete, reason = await manager.check_loop_completion(
+            "loop1", "Test query", mock_judge_handler
+        )
+
+        assert should_complete is True
+        assert "sufficient" in reason.lower() or "judge" in reason.lower()
+        assert mock_judge_handler.assess.called
+
+    async def test_check_loop_completion_with_insufficient_evidence(self, mock_judge_handler):
+        """check_loop_completion should return False when judge says insufficient."""
+        manager = WorkflowManager()
+        await manager.add_loop("loop1", "Test query")
+
+        # Add minimal evidence
+        evidence = [
+            Evidence(
+                content="Minimal evidence",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test",
+                    url="https://example.com",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.3,
+            )
+        ]
+        await manager.add_loop_evidence("loop1", evidence)
+
+        # Mock judge to say insufficient
+        from src.utils.models import AssessmentDetails
+
+        mock_judge_handler.assess = AsyncMock(
+            return_value=JudgeAssessment(
+                details=AssessmentDetails(
+                    mechanism_score=3,
+                    mechanism_reasoning="Test mechanism reasoning that is long enough",
+                    clinical_evidence_score=3,
+                    clinical_reasoning="Test clinical reasoning that is long enough",
+                    drug_candidates=[],
+                    key_findings=[],
+                ),
+                sufficient=False,
+                confidence=0.4,
+                recommendation="continue",
+                reasoning="Need more evidence to provide a comprehensive answer.",
+            )
+        )
+
+        should_complete, reason = await manager.check_loop_completion(
+            "loop1", "Test query", mock_judge_handler
+        )
+
+        assert should_complete is False
+        assert "judge" in reason.lower() or "evidence" in reason.lower()
+        assert mock_judge_handler.assess.called
+
+    async def test_check_loop_completion_with_no_evidence(self, mock_judge_handler):
+        """check_loop_completion should return False when no evidence exists."""
+        manager = WorkflowManager()
+        await manager.add_loop("loop1", "Test query")
+
+        # Don't add any evidence
+
+        should_complete, reason = await manager.check_loop_completion(
+            "loop1", "Test query", mock_judge_handler
+        )
+
+        assert should_complete is False
+        assert "no evidence" in reason.lower() or "not" in reason.lower()
+        # Judge should not be called if no evidence
+        assert not mock_judge_handler.assess.called
+
+    async def test_check_loop_completion_handles_judge_error(self, mock_judge_handler):
+        """check_loop_completion should handle judge errors gracefully."""
+        manager = WorkflowManager()
+        await manager.add_loop("loop1", "Test query")
+
+        evidence = [
+            Evidence(
+                content="Test evidence",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test",
+                    url="https://example.com",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.8,
+            )
+        ]
+        await manager.add_loop_evidence("loop1", evidence)
+
+        # Mock judge to raise error
+        mock_judge_handler.assess = AsyncMock(side_effect=Exception("Judge error"))
+
+        should_complete, reason = await manager.check_loop_completion(
+            "loop1", "Test query", mock_judge_handler
+        )
+
+        assert should_complete is False
+        assert "error" in reason.lower() or "failed" in reason.lower()
+
+    async def test_parallel_loops_with_judge_early_termination(
+        self, mock_judge_handler, mock_iterative_flow
+    ):
+        """Parallel loops should terminate early when judge says sufficient."""
+        manager = WorkflowManager()
+
+        # Create multiple loops
+        loop_configs = [
+            {"loop_id": "loop1", "query": "Query 1"},
+            {"loop_id": "loop2", "query": "Query 2"},
+        ]
+
+        # Define loop function that extracts loop_func from config if needed
+        async def loop_func(config: dict) -> str:
+            return await mock_iterative_flow.run(config.get("query", ""))
+
+        # Add evidence to loop1 that will trigger early completion
+        await manager.add_loop("loop1", "Query 1")
+        evidence = [
+            Evidence(
+                content="Comprehensive evidence for query 1",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test",
+                    url="https://example.com",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.95,
+            )
+        ]
+        await manager.add_loop_evidence("loop1", evidence)
+
+        # Mock judge to say sufficient for loop1
+        call_count = {"count": 0}
+
+        def mock_assess(query: str, evidence_list: list[Evidence]) -> JudgeAssessment:
+            from src.utils.models import AssessmentDetails
+
+            call_count["count"] += 1
+            if "Query 1" in query or len(evidence_list) > 0:
+                return JudgeAssessment(
+                    details=AssessmentDetails(
+                        mechanism_score=5,
+                        mechanism_reasoning="Test mechanism reasoning that is long enough",
+                        clinical_evidence_score=5,
+                        clinical_reasoning="Test clinical reasoning that is long enough",
+                        drug_candidates=[],
+                        key_findings=[],
+                    ),
+                    sufficient=True,
+                    confidence=0.95,
+                    recommendation="synthesize",
+                    reasoning="Sufficient evidence has been collected to answer the query.",
+                )
+            return JudgeAssessment(
+                details=AssessmentDetails(
+                    mechanism_score=3,
+                    mechanism_reasoning="Test mechanism reasoning that is long enough",
+                    clinical_evidence_score=3,
+                    clinical_reasoning="Test clinical reasoning that is long enough",
+                    drug_candidates=[],
+                    key_findings=[],
+                ),
+                sufficient=False,
+                confidence=0.5,
+                recommendation="continue",
+                reasoning="Need more evidence to provide a comprehensive answer.",
+            )
+
+        mock_judge_handler.assess = AsyncMock(side_effect=mock_assess)
+
+        # Run loops in parallel
+        with patch("src.middleware.workflow_manager.get_workflow_state") as mock_state:
+            mock_state_obj = MagicMock()
+            mock_state_obj.evidence = []
+            mock_state.return_value = mock_state_obj
+
+            results = await manager.run_loops_parallel(
+                loop_configs, loop_func=loop_func, judge_handler=mock_judge_handler
+            )
+
+        # Both loops should complete
+        assert len(results) == 2
+        assert all(isinstance(r, str) for r in results)
+
+    async def test_parallel_loops_aggregate_evidence(self, mock_judge_handler):
+        """Parallel loops should aggregate evidence from all loops."""
+        manager = WorkflowManager()
+
+        # Create loops
+        await manager.add_loop("loop1", "Query 1")
+        await manager.add_loop("loop2", "Query 2")
+
+        # Add evidence to each loop
+        evidence1 = [
+            Evidence(
+                content="Evidence from loop 1",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test 1",
+                    url="https://example.com/1",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.8,
+            )
+        ]
+        evidence2 = [
+            Evidence(
+                content="Evidence from loop 2",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test 2",
+                    url="https://example.com/2",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.9,
+            )
+        ]
+
+        await manager.add_loop_evidence("loop1", evidence1)
+        await manager.add_loop_evidence("loop2", evidence2)
+
+        # Get evidence from both loops
+        evidence1_retrieved = await manager.get_loop_evidence("loop1")
+        evidence2_retrieved = await manager.get_loop_evidence("loop2")
+
+        assert len(evidence1_retrieved) == 1
+        assert len(evidence2_retrieved) == 1
+        assert evidence1_retrieved[0].content == "Evidence from loop 1"
+        assert evidence2_retrieved[0].content == "Evidence from loop 2"
+
+    async def test_loop_status_updated_on_completion(self, mock_judge_handler):
+        """Loop status should be updated when judge determines completion."""
+        manager = WorkflowManager()
+        await manager.add_loop("loop1", "Test query")
+
+        # Add sufficient evidence
+        evidence = [
+            Evidence(
+                content="Sufficient evidence",
+                citation=Citation(
+                    source="rag",  # Use valid SourceName
+                    title="Test",
+                    url="https://example.com",
+                    date="2024-01-01",
+                    authors=[],
+                ),
+                relevance=0.95,
+            )
+        ]
+        await manager.add_loop_evidence("loop1", evidence)
+
+        from src.utils.models import AssessmentDetails
+
+        mock_judge_handler.assess = AsyncMock(
+            return_value=JudgeAssessment(
+                details=AssessmentDetails(
+                    mechanism_score=5,
+                    mechanism_reasoning="Test mechanism reasoning that is long enough",
+                    clinical_evidence_score=5,
+                    clinical_reasoning="Test clinical reasoning that is long enough",
+                    drug_candidates=[],
+                    key_findings=[],
+                ),
+                sufficient=True,
+                confidence=0.95,
+                recommendation="synthesize",
+                reasoning="Complete evidence has been collected to answer the query.",
+            )
+        )
+
+        # Check completion (this should update status internally if implemented)
+        should_complete, _ = await manager.check_loop_completion(
+            "loop1", "Test query", mock_judge_handler
+        )
+
+        assert should_complete is True
+        # Status update would happen in run_loops_parallel, not in check_loop_completion
+        loop = await manager.get_loop("loop1")
+        assert loop is not None
+        # Status might still be "pending" or "running" until run_loops_parallel updates it
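Read together, these tests pin down the call shape for judge-gated parallel research: `run_loops_parallel` takes a list of per-loop configs and an async loop function, with an optional `judge_handler` that can end a loop early. A minimal sketch of that wiring, using only calls exercised above; the trivial `loop_func` is a stand-in for a real research pass, and omitting `judge_handler` (as the middleware tests do) leaves completion to the loops themselves:

```python
import asyncio

from src.middleware.workflow_manager import WorkflowManager


async def main() -> None:
    manager = WorkflowManager()
    configs = [
        {"loop_id": "loop1", "query": "Query 1"},
        {"loop_id": "loop2", "query": "Query 2"},
    ]
    for cfg in configs:
        await manager.add_loop(cfg["loop_id"], cfg["query"])

    async def loop_func(config: dict) -> str:
        # A real loop would gather evidence and call add_loop_evidence here.
        return f"Report for {config['query']}"

    # Pass judge_handler=<a real JudgeHandler> to enable early-termination checks.
    results = await manager.run_loops_parallel(configs, loop_func=loop_func)
    for cfg, report in zip(configs, results):
        print(cfg["loop_id"], "->", report)


if __name__ == "__main__":
    asyncio.run(main())
```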
tests/integration/test_rag_integration.py
ADDED
@@ -0,0 +1,343 @@
+"""Integration tests for RAG integration.
+
+These tests require OPENAI_API_KEY and may make real API calls.
+Marked with @pytest.mark.integration to skip in unit test runs.
+"""
+
+import pytest
+
+from src.services.llamaindex_rag import get_rag_service
+from src.tools.rag_tool import create_rag_tool
+from src.tools.search_handler import SearchHandler
+from src.tools.tool_executor import execute_agent_task
+from src.utils.config import settings
+from src.utils.models import AgentTask, Citation, Evidence
+
+
+@pytest.mark.integration
+class TestRAGServiceIntegration:
+    """Integration tests for LlamaIndexRAGService."""
+
+    @pytest.mark.asyncio
+    async def test_rag_service_ingest_and_retrieve(self):
+        """RAG service should ingest and retrieve evidence."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        # Create RAG service
+        rag_service = get_rag_service(collection_name="test_integration")
+
+        # Create sample evidence
+        evidence_list = [
+            Evidence(
+                content="Metformin is a first-line treatment for type 2 diabetes. It works by reducing glucose production in the liver and improving insulin sensitivity.",
+                citation=Citation(
+                    source="pubmed",
+                    title="Metformin Mechanism of Action",
+                    url="https://pubmed.ncbi.nlm.nih.gov/12345678/",
+                    date="2024-01-15",
+                    authors=["Smith J", "Johnson M"],
+                ),
+                relevance=0.9,
+            ),
+            Evidence(
+                content="Recent studies suggest metformin may have neuroprotective effects in Alzheimer's disease models.",
+                citation=Citation(
+                    source="pubmed",
+                    title="Metformin and Neuroprotection",
+                    url="https://pubmed.ncbi.nlm.nih.gov/12345679/",
+                    date="2024-02-20",
+                    authors=["Brown K", "Davis L"],
+                ),
+                relevance=0.85,
+            ),
+        ]
+
+        # Ingest evidence
+        rag_service.ingest_evidence(evidence_list)
+
+        # Retrieve evidence
+        results = rag_service.retrieve("metformin diabetes", top_k=2)
+
+        # Assert
+        assert len(results) > 0
+        assert any("metformin" in r["text"].lower() for r in results)
+        assert all("text" in r for r in results)
+        assert all("metadata" in r for r in results)
+
+        # Cleanup
+        rag_service.clear_collection()
+
+    @pytest.mark.asyncio
+    async def test_rag_service_query(self):
+        """RAG service should synthesize responses from ingested evidence."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        rag_service = get_rag_service(collection_name="test_query")
+
+        # Ingest evidence
+        evidence_list = [
+            Evidence(
+                content="Python is a high-level programming language known for its simplicity and readability.",
+                citation=Citation(
+                    source="pubmed",
+                    title="Python Programming",
+                    url="https://example.com/python",
+                    date="2024",
+                    authors=["Author"],
+                ),
+            )
+        ]
+        rag_service.ingest_evidence(evidence_list)
+
+        # Query
+        response = rag_service.query("What is Python?", top_k=1)
+
+        assert isinstance(response, str)
+        assert len(response) > 0
+        assert "python" in response.lower()
+
+        # Cleanup
+        rag_service.clear_collection()
+
+
+@pytest.mark.integration
+class TestRAGToolIntegration:
+    """Integration tests for RAGTool."""
+
+    @pytest.mark.asyncio
+    async def test_rag_tool_search(self):
+        """RAGTool should search RAG service and return Evidence objects."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        # Create RAG service and ingest evidence
+        rag_service = get_rag_service(collection_name="test_rag_tool")
+        evidence_list = [
+            Evidence(
+                content="Machine learning is a subset of artificial intelligence.",
+                citation=Citation(
+                    source="pubmed",
+                    title="ML Basics",
+                    url="https://example.com/ml",
+                    date="2024",
+                    authors=["ML Expert"],
+                ),
+            )
+        ]
+        rag_service.ingest_evidence(evidence_list)
+
+        # Create RAG tool
+        tool = create_rag_tool(rag_service=rag_service)
+
+        # Search
+        results = await tool.search("machine learning", max_results=5)
+
+        # Assert
+        assert len(results) > 0
+        assert all(isinstance(e, Evidence) for e in results)
+        assert results[0].citation.source == "rag"
+        assert (
+            "machine learning" in results[0].content.lower()
+            or "artificial intelligence" in results[0].content.lower()
+        )
+
+        # Cleanup
+        rag_service.clear_collection()
+
+    @pytest.mark.asyncio
+    async def test_rag_tool_empty_collection(self):
+        """RAGTool should return empty list when collection is empty."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        rag_service = get_rag_service(collection_name="test_empty")
+        rag_service.clear_collection()  # Ensure empty
+
+        tool = create_rag_tool(rag_service=rag_service)
+        results = await tool.search("any query")
+
+        assert results == []
+
+
+@pytest.mark.integration
+class TestRAGAgentIntegration:
+    """Integration tests for RAGAgent in tool executor."""
+
+    @pytest.mark.asyncio
+    async def test_rag_agent_execution(self):
+        """RAGAgent should execute and return ToolAgentOutput."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        # Setup: Ingest evidence into RAG
+        rag_service = get_rag_service(collection_name="test_rag_agent")
+        evidence_list = [
+            Evidence(
+                content="Deep learning uses neural networks with multiple layers.",
+                citation=Citation(
+                    source="pubmed",
+                    title="Deep Learning",
+                    url="https://example.com/dl",
+                    date="2024",
+                    authors=["DL Researcher"],
+                ),
+            )
+        ]
+        rag_service.ingest_evidence(evidence_list)
+
+        # Execute RAGAgent task
+        task = AgentTask(
+            agent="RAGAgent",
+            query="deep learning",
+            gap="Need information about deep learning",
+        )
+
+        result = await execute_agent_task(task)
+
+        # Assert
+        assert result.output
+        assert "deep learning" in result.output.lower() or "neural network" in result.output.lower()
+        assert len(result.sources) > 0
+
+        # Cleanup
+        rag_service.clear_collection()
+
+
+@pytest.mark.integration
+class TestRAGSearchHandlerIntegration:
+    """Integration tests for RAG in SearchHandler."""
+
+    @pytest.mark.asyncio
+    async def test_search_handler_with_rag(self):
+        """SearchHandler should work with RAG tool included."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        # Setup: Create RAG service and ingest some evidence
+        rag_service = get_rag_service(collection_name="test_search_handler")
+        evidence_list = [
+            Evidence(
+                content="Test evidence for search handler integration.",
+                citation=Citation(
+                    source="pubmed",
+                    title="Test Evidence",
+                    url="https://example.com/test",
+                    date="2024",
+                    authors=["Tester"],
+                ),
+            )
+        ]
+        rag_service.ingest_evidence(evidence_list)
+
+        # Create SearchHandler with RAG
+        handler = SearchHandler(
+            tools=[],  # No other tools
+            include_rag=True,
+            auto_ingest_to_rag=False,  # Don't auto-ingest (already has data)
+        )
+
+        # Execute search
+        result = await handler.execute("test evidence", max_results_per_tool=5)
+
+        # Assert
+        assert result.total_found > 0
+        assert "rag" in result.sources_searched
+        assert any(e.citation.source == "rag" for e in result.evidence)
+
+        # Cleanup
+        rag_service.clear_collection()
+
+    @pytest.mark.asyncio
+    async def test_search_handler_auto_ingest(self):
+        """SearchHandler should auto-ingest evidence into RAG."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        # Create empty RAG service
+        rag_service = get_rag_service(collection_name="test_auto_ingest")
+        rag_service.clear_collection()
+
+        # Create mock tool that returns evidence
+        from unittest.mock import AsyncMock
+
+        mock_tool = AsyncMock()
+        mock_tool.name = "pubmed"
+        mock_tool.search = AsyncMock(
+            return_value=[
+                Evidence(
+                    content="Evidence to be ingested",
+                    citation=Citation(
+                        source="pubmed",
+                        title="Test",
+                        url="https://example.com",
+                        date="2024",
+                        authors=[],
+                    ),
+                )
+            ]
+        )
+
+        # Create handler with auto-ingest enabled
+        handler = SearchHandler(
+            tools=[mock_tool],
+            include_rag=False,  # Don't include RAG as search tool
+            auto_ingest_to_rag=True,
+        )
+        handler._rag_service = rag_service  # Inject RAG service
+
+        # Execute search
+        await handler.execute("test query")
+
+        # Verify evidence was ingested
+        rag_results = rag_service.retrieve("Evidence to be ingested", top_k=1)
+        assert len(rag_results) > 0
+
+        # Cleanup
+        rag_service.clear_collection()
+
+
+@pytest.mark.integration
+class TestRAGHybridSearchIntegration:
+    """Integration tests for hybrid search (RAG + database)."""
+
+    @pytest.mark.asyncio
+    async def test_hybrid_search_rag_and_pubmed(self):
+        """SearchHandler should support RAG + PubMed hybrid search."""
+        if not settings.openai_api_key:
+            pytest.skip("OPENAI_API_KEY required for RAG integration tests")
+
+        # Setup: Ingest evidence into RAG
+        rag_service = get_rag_service(collection_name="test_hybrid")
+        evidence_list = [
+            Evidence(
+                content="Previously collected evidence about metformin.",
+                citation=Citation(
+                    source="pubmed",
+                    title="Previous Research",
+                    url="https://example.com/prev",
+                    date="2024",
+                    authors=[],
+                ),
+            )
+        ]
+        rag_service.ingest_evidence(evidence_list)
+
+        # Note: This test would require real PubMed API access
+        # For now, we'll just test that the handler can be created with both tools
+        from src.tools.pubmed import PubMedTool
+
+        handler = SearchHandler(
+            tools=[PubMedTool()],
+            include_rag=True,
+            auto_ingest_to_rag=True,
+        )
+
+        # Verify handler has both tools
+        tool_names = [t.name for t in handler.tools]
+        assert "pubmed" in tool_names
+        assert "rag" in tool_names
+
+        # Cleanup
+        rag_service.clear_collection()
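These tests double as a usage walkthrough of the RAG service's ingest → retrieve → query lifecycle. A minimal sketch of that lifecycle, assuming OPENAI_API_KEY is set and using only the calls shown above; the evidence content and the "demo" collection name are illustrative:

```python
from src.services.llamaindex_rag import get_rag_service
from src.utils.models import Citation, Evidence

rag = get_rag_service(collection_name="demo")

# Ingest: index Evidence objects into the collection.
rag.ingest_evidence(
    [
        Evidence(
            content="Aspirin irreversibly inhibits COX-1 and COX-2.",
            citation=Citation(
                source="pubmed",
                title="Aspirin Pharmacology",
                url="https://example.com/aspirin",
                date="2024",
                authors=["Example A"],
            ),
        )
    ]
)

# Retrieve: raw chunks come back as dicts with "text" and "metadata" keys.
for hit in rag.retrieve("aspirin mechanism", top_k=1):
    print(hit["text"])

# Query: synthesizes a natural-language answer over the retrieved chunks.
print(rag.query("How does aspirin work?", top_k=1))

rag.clear_collection()
```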
tests/integration/test_research_flows.py
ADDED
@@ -0,0 +1,584 @@
+"""Integration tests for research flows.
+
+These tests require API keys and may make real API calls.
+Marked with @pytest.mark.integration to skip in unit test runs.
+"""
+
+import pytest
+
+from src.agent_factory.agents import (
+    create_deep_flow,
+    create_iterative_flow,
+    create_planner_agent,
+)
+from src.orchestrator.graph_orchestrator import create_graph_orchestrator
+from src.utils.config import settings
+
+
+@pytest.mark.integration
+class TestPlannerAgentIntegration:
+    """Integration tests for PlannerAgent with real API calls."""
+
+    @pytest.mark.asyncio
+    async def test_planner_agent_creates_plan(self):
+        """PlannerAgent should create a valid report plan with real API."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        planner = create_planner_agent()
+        result = await planner.run("What are the main features of Python programming language?")
+
+        assert result.report_title
+        assert len(result.report_outline) > 0
+        assert result.report_outline[0].title
+        assert result.report_outline[0].key_question
+
+    @pytest.mark.asyncio
+    async def test_planner_agent_includes_background_context(self):
+        """PlannerAgent should include background context in plan."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        planner = create_planner_agent()
+        result = await planner.run("Explain quantum computing basics")
+
+        assert result.background_context
+        assert len(result.background_context) > 50  # Should have substantial context
+
+
+@pytest.mark.integration
+class TestIterativeResearchFlowIntegration:
+    """Integration tests for IterativeResearchFlow with real API calls."""
+
+    @pytest.mark.asyncio
+    async def test_iterative_flow_completes_simple_query(self):
+        """IterativeResearchFlow should complete a simple research query."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_iterative_flow(max_iterations=2, max_time_minutes=2)
+        result = await flow.run(
+            query="What is the capital of France?",
+            output_length="A short paragraph",
+        )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+        # Should mention Paris
+        assert "paris" in result.lower() or "france" in result.lower()
+
+    @pytest.mark.asyncio
+    async def test_iterative_flow_respects_max_iterations(self):
+        """IterativeResearchFlow should respect max_iterations limit."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_iterative_flow(max_iterations=1, max_time_minutes=5)
+        result = await flow.run(query="What are the main features of Python?")
+
+        assert isinstance(result, str)
+        # Should complete within 1 iteration (or hit max)
+        assert flow.iteration <= 1
+
+    @pytest.mark.asyncio
+    async def test_iterative_flow_with_background_context(self):
+        """IterativeResearchFlow should use background context."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_iterative_flow(max_iterations=2, max_time_minutes=2)
+        result = await flow.run(
+            query="What is machine learning?",
+            background_context="Machine learning is a subset of artificial intelligence.",
+        )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+
+
+@pytest.mark.integration
+class TestDeepResearchFlowIntegration:
+    """Integration tests for DeepResearchFlow with real API calls."""
+
+    @pytest.mark.asyncio
+    async def test_deep_flow_creates_multi_section_report(self):
+        """DeepResearchFlow should create a report with multiple sections."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_deep_flow(
+            max_iterations=1,  # Keep it short for testing
+            max_time_minutes=3,
+        )
+        result = await flow.run("What are the main features of Python programming language?")
+
+        assert isinstance(result, str)
+        assert len(result) > 100  # Should have substantial content
+        # Should have section structure
+        assert "#" in result or "##" in result
+
+    @pytest.mark.asyncio
+    async def test_deep_flow_uses_long_writer(self):
+        """DeepResearchFlow should use long writer by default."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_deep_flow(
+            max_iterations=1,
+            max_time_minutes=3,
+            use_long_writer=True,
+        )
+        result = await flow.run("Explain the basics of quantum computing")
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+
+    @pytest.mark.asyncio
+    async def test_deep_flow_uses_proofreader_when_specified(self):
+        """DeepResearchFlow should use proofreader when use_long_writer=False."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_deep_flow(
+            max_iterations=1,
+            max_time_minutes=3,
+            use_long_writer=False,
+        )
+        result = await flow.run("What is artificial intelligence?")
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+
+
+@pytest.mark.integration
+class TestGraphOrchestratorIntegration:
+    """Integration tests for GraphOrchestrator with real API calls."""
+
+    @pytest.mark.asyncio
+    async def test_graph_orchestrator_iterative_mode(self):
+        """GraphOrchestrator should run in iterative mode."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        orchestrator = create_graph_orchestrator(
+            mode="iterative",
+            max_iterations=1,
+            max_time_minutes=2,
+        )
+
+        events = []
+        async for event in orchestrator.run("What is Python?"):
+            events.append(event)
+
+        assert len(events) > 0
+        event_types = [e.type for e in events]
+        assert "started" in event_types
+        assert "complete" in event_types
+
+    @pytest.mark.asyncio
+    async def test_graph_orchestrator_deep_mode(self):
+        """GraphOrchestrator should run in deep mode."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        orchestrator = create_graph_orchestrator(
+            mode="deep",
+            max_iterations=1,
+            max_time_minutes=3,
+        )
+
+        events = []
+        async for event in orchestrator.run("What are the main features of Python?"):
+            events.append(event)
+
+        assert len(events) > 0
+        event_types = [e.type for e in events]
+        assert "started" in event_types
+        assert "complete" in event_types
+
+    @pytest.mark.asyncio
+    async def test_graph_orchestrator_auto_mode(self):
+        """GraphOrchestrator should auto-detect research mode."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        orchestrator = create_graph_orchestrator(
+            mode="auto",
+            max_iterations=1,
+            max_time_minutes=2,
+        )
+
+        events = []
+        async for event in orchestrator.run("What is Python?"):
+            events.append(event)
+
+        assert len(events) > 0
+        # Should complete successfully regardless of mode
+        event_types = [e.type for e in events]
+        assert "complete" in event_types
+
+
+@pytest.mark.integration
+class TestGraphOrchestrationIntegration:
+    """Integration tests for graph-based orchestration with real API calls."""
+
+    @pytest.mark.asyncio
+    async def test_iterative_flow_with_graph_execution(self):
+        """IterativeResearchFlow should work with graph execution enabled."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_iterative_flow(
+            max_iterations=1,
+            max_time_minutes=2,
+            use_graph=True,
+        )
+        result = await flow.run(query="What is the capital of France?")
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+        # Should mention Paris
+        assert "paris" in result.lower() or "france" in result.lower()
+
+    @pytest.mark.asyncio
+    async def test_deep_flow_with_graph_execution(self):
+        """DeepResearchFlow should work with graph execution enabled."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_deep_flow(
+            max_iterations=1,
+            max_time_minutes=3,
+            use_graph=True,
+        )
+        result = await flow.run("What are the main features of Python programming language?")
+
+        assert isinstance(result, str)
+        assert len(result) > 100  # Should have substantial content
+
+    @pytest.mark.asyncio
+    async def test_graph_orchestrator_with_graph_execution(self):
+        """GraphOrchestrator should work with graph execution enabled."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        orchestrator = create_graph_orchestrator(
+            mode="iterative",
+            max_iterations=1,
+            max_time_minutes=2,
+            use_graph=True,
+        )
+
+        events = []
+        async for event in orchestrator.run("What is Python?"):
+            events.append(event)
+
+        assert len(events) > 0
+        event_types = [e.type for e in events]
+        assert "started" in event_types
+        assert "complete" in event_types
+
+        # Extract final report from complete event
+        complete_events = [e for e in events if e.type == "complete"]
+        assert len(complete_events) > 0
+        final_report = complete_events[0].message
+        assert isinstance(final_report, str)
+        assert len(final_report) > 0
+
+    @pytest.mark.asyncio
+    async def test_graph_orchestrator_parallel_execution(self):
+        """GraphOrchestrator should support parallel execution in deep mode."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        orchestrator = create_graph_orchestrator(
+            mode="deep",
+            max_iterations=1,
+            max_time_minutes=3,
+            use_graph=True,
+        )
+
+        events = []
+        async for event in orchestrator.run("What are the main features of Python?"):
+            events.append(event)
+
+        assert len(events) > 0
+        event_types = [e.type for e in events]
+        assert "started" in event_types
+        assert "complete" in event_types
+
+    @pytest.mark.asyncio
+    async def test_graph_vs_chain_execution_comparison(self):
+        """Both graph and chain execution should produce similar results."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        query = "What is the capital of France?"
+
+        # Run with graph execution
+        flow_graph = create_iterative_flow(
+            max_iterations=1,
+            max_time_minutes=2,
+            use_graph=True,
+        )
+        result_graph = await flow_graph.run(query=query)
+
+        # Run with agent chains
+        flow_chains = create_iterative_flow(
+            max_iterations=1,
+            max_time_minutes=2,
+            use_graph=False,
+        )
+        result_chains = await flow_chains.run(query=query)
+
+        # Both should produce valid results
+        assert isinstance(result_graph, str)
+        assert isinstance(result_chains, str)
+        assert len(result_graph) > 0
+        assert len(result_chains) > 0
+
+        # Both should mention the answer (Paris)
+        assert "paris" in result_graph.lower() or "france" in result_graph.lower()
+        assert "paris" in result_chains.lower() or "france" in result_chains.lower()
+
+
+@pytest.mark.integration
+class TestReportSynthesisIntegration:
+    """Integration tests for report synthesis with writer agents."""
+
+    @pytest.mark.asyncio
+    async def test_iterative_flow_generates_report(self):
+        """IterativeResearchFlow should generate a report with writer agent."""
+        if not settings.has_openai_key and not settings.has_anthropic_key:
+            pytest.skip("No OpenAI or Anthropic API key available")
+
+        flow = create_iterative_flow(max_iterations=1, max_time_minutes=2)
+        result = await flow.run(
+            query="What is the capital of France?",
+            output_length="A short paragraph",
+        )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+        # Should be a formatted report
+        assert "paris" in result.lower() or "france" in result.lower()
+        # Should have some structure (markdown headers or content)
+        assert len(result) > 50
+
+    @pytest.mark.asyncio
+    async def test_iterative_flow_includes_citations(self):
+        """IterativeResearchFlow should include citations in the report."""
|
| 371 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 372 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 373 |
+
|
| 374 |
+
flow = create_iterative_flow(max_iterations=1, max_time_minutes=2)
|
| 375 |
+
result = await flow.run(
|
| 376 |
+
query="What is machine learning?",
|
| 377 |
+
output_length="A short paragraph",
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
assert isinstance(result, str)
|
| 381 |
+
# Should have some form of citations or references
|
| 382 |
+
# (either [1], [2] format or References section)
|
| 383 |
+
# Note: Citations may not always be present depending on findings
|
| 384 |
+
# This is a soft check - just verify report was generated
|
| 385 |
+
assert len(result) > 0
|
| 386 |
+
|
| 387 |
+
@pytest.mark.asyncio
|
| 388 |
+
async def test_iterative_flow_handles_empty_findings(self):
|
| 389 |
+
"""IterativeResearchFlow should handle empty findings gracefully."""
|
| 390 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 391 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 392 |
+
|
| 393 |
+
flow = create_iterative_flow(max_iterations=1, max_time_minutes=1)
|
| 394 |
+
# Use a query that might not return findings quickly
|
| 395 |
+
result = await flow.run(
|
| 396 |
+
query="Test query with no findings",
|
| 397 |
+
output_length="A short paragraph",
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Should still return a report (even if minimal)
|
| 401 |
+
assert isinstance(result, str)
|
| 402 |
+
# Writer agent should handle empty findings with fallback
|
| 403 |
+
|
| 404 |
+
@pytest.mark.asyncio
|
| 405 |
+
async def test_deep_flow_with_long_writer(self):
|
| 406 |
+
"""DeepResearchFlow should use long writer to create sections."""
|
| 407 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 408 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 409 |
+
|
| 410 |
+
flow = create_deep_flow(
|
| 411 |
+
max_iterations=1,
|
| 412 |
+
max_time_minutes=3,
|
| 413 |
+
use_long_writer=True,
|
| 414 |
+
)
|
| 415 |
+
result = await flow.run("What are the main features of Python programming language?")
|
| 416 |
+
|
| 417 |
+
assert isinstance(result, str)
|
| 418 |
+
assert len(result) > 100 # Should have substantial content
|
| 419 |
+
# Should have section structure (table of contents or sections)
|
| 420 |
+
has_structure = (
|
| 421 |
+
"##" in result
|
| 422 |
+
or "#" in result
|
| 423 |
+
or "table of contents" in result.lower()
|
| 424 |
+
or "introduction" in result.lower()
|
| 425 |
+
)
|
| 426 |
+
# Long writer should create structured report
|
| 427 |
+
assert has_structure or len(result) > 200
|
| 428 |
+
|
| 429 |
+
@pytest.mark.asyncio
|
| 430 |
+
async def test_deep_flow_creates_sections(self):
|
| 431 |
+
"""DeepResearchFlow should create multiple sections in the report."""
|
| 432 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 433 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 434 |
+
|
| 435 |
+
flow = create_deep_flow(
|
| 436 |
+
max_iterations=1,
|
| 437 |
+
max_time_minutes=3,
|
| 438 |
+
use_long_writer=True,
|
| 439 |
+
)
|
| 440 |
+
result = await flow.run("Explain the basics of quantum computing")
|
| 441 |
+
|
| 442 |
+
assert isinstance(result, str)
|
| 443 |
+
# Should have multiple sections (indicated by headers)
|
| 444 |
+
# Should have at least some structure
|
| 445 |
+
assert len(result) > 100
|
| 446 |
+
|
| 447 |
+
@pytest.mark.asyncio
|
| 448 |
+
async def test_deep_flow_aggregates_references(self):
|
| 449 |
+
"""DeepResearchFlow should aggregate references from all sections."""
|
| 450 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 451 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 452 |
+
|
| 453 |
+
flow = create_deep_flow(
|
| 454 |
+
max_iterations=1,
|
| 455 |
+
max_time_minutes=3,
|
| 456 |
+
use_long_writer=True,
|
| 457 |
+
)
|
| 458 |
+
result = await flow.run("What are the main features of Python programming language?")
|
| 459 |
+
|
| 460 |
+
assert isinstance(result, str)
|
| 461 |
+
# Long writer should aggregate references at the end
|
| 462 |
+
# Check for references section or citation format
|
| 463 |
+
# Note: References may not always be present
|
| 464 |
+
# Just verify report structure is correct
|
| 465 |
+
assert len(result) > 100
|
| 466 |
+
|
| 467 |
+
@pytest.mark.asyncio
|
| 468 |
+
async def test_deep_flow_with_proofreader(self):
|
| 469 |
+
"""DeepResearchFlow should use proofreader to finalize report."""
|
| 470 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 471 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 472 |
+
|
| 473 |
+
flow = create_deep_flow(
|
| 474 |
+
max_iterations=1,
|
| 475 |
+
max_time_minutes=3,
|
| 476 |
+
use_long_writer=False, # Use proofreader instead
|
| 477 |
+
)
|
| 478 |
+
result = await flow.run("What is artificial intelligence?")
|
| 479 |
+
|
| 480 |
+
assert isinstance(result, str)
|
| 481 |
+
assert len(result) > 0
|
| 482 |
+
# Proofreader should create polished report
|
| 483 |
+
# Should have some structure
|
| 484 |
+
assert len(result) > 50
|
| 485 |
+
|
| 486 |
+
@pytest.mark.asyncio
|
| 487 |
+
async def test_proofreader_removes_duplicates(self):
|
| 488 |
+
"""Proofreader should remove duplicate content from report."""
|
| 489 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 490 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 491 |
+
|
| 492 |
+
flow = create_deep_flow(
|
| 493 |
+
max_iterations=1,
|
| 494 |
+
max_time_minutes=3,
|
| 495 |
+
use_long_writer=False,
|
| 496 |
+
)
|
| 497 |
+
result = await flow.run("Explain machine learning basics")
|
| 498 |
+
|
| 499 |
+
assert isinstance(result, str)
|
| 500 |
+
# Proofreader should create polished, non-repetitive content
|
| 501 |
+
# This is a soft check - just verify report was generated
|
| 502 |
+
assert len(result) > 0
|
| 503 |
+
|
| 504 |
+
@pytest.mark.asyncio
|
| 505 |
+
async def test_proofreader_adds_summary(self):
|
| 506 |
+
"""Proofreader should add a summary to the report."""
|
| 507 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 508 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 509 |
+
|
| 510 |
+
flow = create_deep_flow(
|
| 511 |
+
max_iterations=1,
|
| 512 |
+
max_time_minutes=3,
|
| 513 |
+
use_long_writer=False,
|
| 514 |
+
)
|
| 515 |
+
result = await flow.run("What is Python programming language?")
|
| 516 |
+
|
| 517 |
+
assert isinstance(result, str)
|
| 518 |
+
# Proofreader should add summary/outline
|
| 519 |
+
# Check for summary indicators
|
| 520 |
+
# Note: Summary format may vary
|
| 521 |
+
# Just verify report was generated
|
| 522 |
+
assert len(result) > 0
|
| 523 |
+
|
| 524 |
+
@pytest.mark.asyncio
|
| 525 |
+
async def test_graph_orchestrator_uses_writer_agents(self):
|
| 526 |
+
"""GraphOrchestrator should use writer agents in iterative mode."""
|
| 527 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 528 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 529 |
+
|
| 530 |
+
orchestrator = create_graph_orchestrator(
|
| 531 |
+
mode="iterative",
|
| 532 |
+
max_iterations=1,
|
| 533 |
+
max_time_minutes=2,
|
| 534 |
+
use_graph=False, # Use agent chains to test writer integration
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
events = []
|
| 538 |
+
async for event in orchestrator.run("What is the capital of France?"):
|
| 539 |
+
events.append(event)
|
| 540 |
+
|
| 541 |
+
assert len(events) > 0
|
| 542 |
+
event_types = [e.type for e in events]
|
| 543 |
+
assert "started" in event_types
|
| 544 |
+
assert "complete" in event_types
|
| 545 |
+
|
| 546 |
+
# Extract final report from complete event
|
| 547 |
+
complete_events = [e for e in events if e.type == "complete"]
|
| 548 |
+
assert len(complete_events) > 0
|
| 549 |
+
final_report = complete_events[0].message
|
| 550 |
+
assert isinstance(final_report, str)
|
| 551 |
+
assert len(final_report) > 0
|
| 552 |
+
# Should have content from writer agent
|
| 553 |
+
assert "paris" in final_report.lower() or "france" in final_report.lower()
|
| 554 |
+
|
| 555 |
+
@pytest.mark.asyncio
|
| 556 |
+
async def test_graph_orchestrator_uses_long_writer_in_deep_mode(self):
|
| 557 |
+
"""GraphOrchestrator should use long writer in deep mode."""
|
| 558 |
+
if not settings.has_openai_key and not settings.has_anthropic_key:
|
| 559 |
+
pytest.skip("No OpenAI or Anthropic API key available")
|
| 560 |
+
|
| 561 |
+
orchestrator = create_graph_orchestrator(
|
| 562 |
+
mode="deep",
|
| 563 |
+
max_iterations=1,
|
| 564 |
+
max_time_minutes=3,
|
| 565 |
+
use_graph=False, # Use agent chains
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
events = []
|
| 569 |
+
async for event in orchestrator.run("What are the main features of Python?"):
|
| 570 |
+
events.append(event)
|
| 571 |
+
|
| 572 |
+
assert len(events) > 0
|
| 573 |
+
event_types = [e.type for e in events]
|
| 574 |
+
assert "started" in event_types
|
| 575 |
+
assert "complete" in event_types
|
| 576 |
+
|
| 577 |
+
# Extract final report
|
| 578 |
+
complete_events = [e for e in events if e.type == "complete"]
|
| 579 |
+
assert len(complete_events) > 0
|
| 580 |
+
final_report = complete_events[0].message
|
| 581 |
+
assert isinstance(final_report, str)
|
| 582 |
+
assert len(final_report) > 0
|
| 583 |
+
# Should have structured content from long writer
|
| 584 |
+
assert len(final_report) > 100
|
tests/unit/agent_factory/test_graph_builder.py
ADDED
@@ -0,0 +1,439 @@
+"""Unit tests for graph builder utilities."""
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from pydantic_ai import Agent
+
+from src.agent_factory.graph_builder import (
+    AgentNode,
+    ConditionalEdge,
+    DecisionNode,
+    GraphBuilder,
+    GraphNode,
+    ParallelNode,
+    ResearchGraph,
+    SequentialEdge,
+    StateNode,
+    create_deep_graph,
+    create_iterative_graph,
+)
+from src.middleware.state_machine import WorkflowState
+
+
+class TestGraphNode:
+    """Tests for GraphNode models."""
+
+    def test_graph_node_creation(self):
+        """Test creating a base GraphNode."""
+        node = GraphNode(node_id="test_node", node_type="agent", description="Test")
+        assert node.node_id == "test_node"
+        assert node.node_type == "agent"
+        assert node.description == "Test"
+
+    def test_agent_node_creation(self):
+        """Test creating an AgentNode."""
+        mock_agent = MagicMock(spec=Agent)
+        node = AgentNode(
+            node_id="agent_1",
+            agent=mock_agent,
+            description="Test agent",
+        )
+        assert node.node_id == "agent_1"
+        assert node.node_type == "agent"
+        assert node.agent == mock_agent
+        assert node.input_transformer is None
+        assert node.output_transformer is None
+
+    def test_agent_node_with_transformers(self):
+        """Test creating an AgentNode with transformers."""
+        mock_agent = MagicMock(spec=Agent)
+
+        def input_transformer(x):
+            return f"input_{x}"
+
+        def output_transformer(x):
+            return f"output_{x}"
+
+        node = AgentNode(
+            node_id="agent_1",
+            agent=mock_agent,
+            input_transformer=input_transformer,
+            output_transformer=output_transformer,
+        )
+        assert node.input_transformer is not None
+        assert node.output_transformer is not None
+
+    def test_state_node_creation(self):
+        """Test creating a StateNode."""
+
+        def state_updater(state: WorkflowState, data: Any) -> WorkflowState:
+            return state
+
+        node = StateNode(
+            node_id="state_1",
+            state_updater=state_updater,
+            description="Test state",
+        )
+        assert node.node_id == "state_1"
+        assert node.node_type == "state"
+        assert node.state_updater is not None
+        assert node.state_reader is None
+
+    def test_decision_node_creation(self):
+        """Test creating a DecisionNode."""
+
+        def decision_func(data: Any) -> str:
+            return "next_node"
+
+        node = DecisionNode(
+            node_id="decision_1",
+            decision_function=decision_func,
+            options=["next_node", "other_node"],
+            description="Test decision",
+        )
+        assert node.node_id == "decision_1"
+        assert node.node_type == "decision"
+        assert len(node.options) == 2
+        assert "next_node" in node.options
+
+    def test_parallel_node_creation(self):
+        """Test creating a ParallelNode."""
+        node = ParallelNode(
+            node_id="parallel_1",
+            parallel_nodes=["node1", "node2", "node3"],
+            description="Test parallel",
+        )
+        assert node.node_id == "parallel_1"
+        assert node.node_type == "parallel"
+        assert len(node.parallel_nodes) == 3
+        assert node.aggregator is None
+
+
+class TestGraphEdge:
+    """Tests for GraphEdge models."""
+
+    def test_sequential_edge_creation(self):
+        """Test creating a SequentialEdge."""
+        edge = SequentialEdge(from_node="node1", to_node="node2")
+        assert edge.from_node == "node1"
+        assert edge.to_node == "node2"
+        assert edge.condition is None
+        assert edge.weight == 1.0
+
+    def test_conditional_edge_creation(self):
+        """Test creating a ConditionalEdge."""
+
+        def condition(data: Any) -> bool:
+            return True
+
+        edge = ConditionalEdge(
+            from_node="node1",
+            to_node="node2",
+            condition=condition,
+            condition_description="Test condition",
+        )
+        assert edge.from_node == "node1"
+        assert edge.to_node == "node2"
+        assert edge.condition is not None
+        assert edge.condition_description == "Test condition"
+
+
+class TestResearchGraph:
+    """Tests for ResearchGraph class."""
+
+    def test_graph_creation(self):
+        """Test creating an empty graph."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        assert graph.entry_node == "start"
+        assert len(graph.exit_nodes) == 1
+        assert graph.exit_nodes[0] == "end"
+        assert len(graph.nodes) == 0
+        assert len(graph.edges) == 0
+
+    def test_add_node(self):
+        """Test adding a node to the graph."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        node = GraphNode(node_id="node1", node_type="agent", description="Test")
+        graph.add_node(node)
+        assert "node1" in graph.nodes
+        assert graph.get_node("node1") == node
+
+    def test_add_node_duplicate_raises_error(self):
+        """Test that adding duplicate node raises ValueError."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        node = GraphNode(node_id="node1", node_type="agent", description="Test")
+        graph.add_node(node)
+        with pytest.raises(ValueError, match="already exists"):
+            graph.add_node(node)
+
+    def test_add_edge(self):
+        """Test adding an edge to the graph."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
+        node2 = GraphNode(node_id="node2", node_type="agent", description="Test")
+        graph.add_node(node1)
+        graph.add_node(node2)
+
+        edge = SequentialEdge(from_node="node1", to_node="node2")
+        graph.add_edge(edge)
+        assert "node1" in graph.edges
+        assert len(graph.edges["node1"]) == 1
+        assert graph.edges["node1"][0] == edge
+
+    def test_add_edge_invalid_source_raises_error(self):
+        """Test that adding edge with invalid source raises ValueError."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        edge = SequentialEdge(from_node="nonexistent", to_node="node2")
+        with pytest.raises(ValueError, match="Source node.*not found"):
+            graph.add_edge(edge)
+
+    def test_add_edge_invalid_target_raises_error(self):
+        """Test that adding edge with invalid target raises ValueError."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
+        graph.add_node(node1)
+        edge = SequentialEdge(from_node="node1", to_node="nonexistent")
+        with pytest.raises(ValueError, match="Target node.*not found"):
+            graph.add_edge(edge)
+
+    def test_get_next_nodes(self):
+        """Test getting next nodes from a node."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
+        node2 = GraphNode(node_id="node2", node_type="agent", description="Test")
+        graph.add_node(node1)
+        graph.add_node(node2)
+        graph.add_edge(SequentialEdge(from_node="node1", to_node="node2"))
+
+        next_nodes = graph.get_next_nodes("node1")
+        assert len(next_nodes) == 1
+        assert next_nodes[0][0] == "node2"
+
+    def test_get_next_nodes_with_condition(self):
+        """Test getting next nodes with conditional edge."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
+        node2 = GraphNode(node_id="node2", node_type="agent", description="Test")
+        node3 = GraphNode(node_id="node3", node_type="agent", description="Test")
+        graph.add_node(node1)
+        graph.add_node(node2)
+        graph.add_node(node3)
+
+        # Add conditional edge that only passes when data is True
+        def condition(data: Any) -> bool:
+            return data is True
+
+        graph.add_edge(SequentialEdge(from_node="node1", to_node="node2"))
+        graph.add_edge(ConditionalEdge(from_node="node1", to_node="node3", condition=condition))
+
+        # With condition True, should get both
+        next_nodes = graph.get_next_nodes("node1", context=True)
+        assert len(next_nodes) == 2
+
+        # With condition False, should only get sequential edge
+        next_nodes = graph.get_next_nodes("node1", context=False)
+        assert len(next_nodes) == 1
+        assert next_nodes[0][0] == "node2"
+
+    def test_validate_empty_graph(self):
+        """Test validating an empty graph."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        errors = graph.validate()
+        assert len(errors) > 0  # Should have errors for missing entry/exit nodes
+
+    def test_validate_valid_graph(self):
+        """Test validating a valid graph."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        start_node = GraphNode(node_id="start", node_type="agent", description="Start")
+        end_node = GraphNode(node_id="end", node_type="agent", description="End")
+        graph.add_node(start_node)
+        graph.add_node(end_node)
+        graph.add_edge(SequentialEdge(from_node="start", to_node="end"))
+
+        errors = graph.validate()
+        assert len(errors) == 0
+
+    def test_validate_unreachable_nodes(self):
+        """Test that validation detects unreachable nodes."""
+        graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
+        start_node = GraphNode(node_id="start", node_type="agent", description="Start")
+        end_node = GraphNode(node_id="end", node_type="agent", description="End")
+        unreachable = GraphNode(node_id="unreachable", node_type="agent", description="Unreachable")
+        graph.add_node(start_node)
+        graph.add_node(end_node)
+        graph.add_node(unreachable)
+        graph.add_edge(SequentialEdge(from_node="start", to_node="end"))
+
+        errors = graph.validate()
+        assert len(errors) > 0
+        assert any("unreachable" in error.lower() for error in errors)
+
+
+class TestGraphBuilder:
+    """Tests for GraphBuilder class."""
+
+    def test_builder_initialization(self):
+        """Test initializing a GraphBuilder."""
+        builder = GraphBuilder()
+        assert builder.graph is not None
+        assert builder.graph.entry_node == ""
+        assert len(builder.graph.exit_nodes) == 0
+
+    def test_add_agent_node(self):
+        """Test adding an agent node."""
+        builder = GraphBuilder()
+        mock_agent = MagicMock(spec=Agent)
+        builder.add_agent_node("agent1", mock_agent, "Test agent")
+        assert "agent1" in builder.graph.nodes
+        node = builder.graph.get_node("agent1")
+        assert isinstance(node, AgentNode)
+        assert node.agent == mock_agent
+
+    def test_add_state_node(self):
+        """Test adding a state node."""
+        builder = GraphBuilder()
+
+        def updater(state: WorkflowState, data: Any) -> WorkflowState:
+            return state
+
+        builder.add_state_node("state1", updater, "Test state")
+        assert "state1" in builder.graph.nodes
+        node = builder.graph.get_node("state1")
+        assert isinstance(node, StateNode)
+
+    def test_add_decision_node(self):
+        """Test adding a decision node."""
+        builder = GraphBuilder()
+
+        def decision_func(data: Any) -> str:
+            return "next"
+
+        builder.add_decision_node("decision1", decision_func, ["next", "other"], "Test")
+        assert "decision1" in builder.graph.nodes
+        node = builder.graph.get_node("decision1")
+        assert isinstance(node, DecisionNode)
+
+    def test_add_parallel_node(self):
+        """Test adding a parallel node."""
+        builder = GraphBuilder()
+        builder.add_parallel_node("parallel1", ["node1", "node2"], "Test")
+        assert "parallel1" in builder.graph.nodes
+        node = builder.graph.get_node("parallel1")
+        assert isinstance(node, ParallelNode)
+        assert len(node.parallel_nodes) == 2
+
+    def test_connect_nodes(self):
+        """Test connecting nodes."""
+        builder = GraphBuilder()
+        builder.add_agent_node("node1", MagicMock(spec=Agent), "Node 1")
+        builder.add_agent_node("node2", MagicMock(spec=Agent), "Node 2")
+        builder.connect_nodes("node1", "node2")
+        assert "node1" in builder.graph.edges
+        assert len(builder.graph.edges["node1"]) == 1
+
+    def test_connect_nodes_with_condition(self):
+        """Test connecting nodes with a condition."""
+        builder = GraphBuilder()
+        builder.add_agent_node("node1", MagicMock(spec=Agent), "Node 1")
+        builder.add_agent_node("node2", MagicMock(spec=Agent), "Node 2")
+
+        def condition(data: Any) -> bool:
+            return True
+
+        builder.connect_nodes("node1", "node2", condition=condition, condition_description="Test")
+        edge = builder.graph.edges["node1"][0]
+        assert isinstance(edge, ConditionalEdge)
+        assert edge.condition is not None
+
+    def test_set_entry_node(self):
+        """Test setting entry node."""
+        builder = GraphBuilder()
+        builder.add_agent_node("start", MagicMock(spec=Agent), "Start")
+        builder.set_entry_node("start")
+        assert builder.graph.entry_node == "start"
+
+    def test_set_exit_nodes(self):
+        """Test setting exit nodes."""
+        builder = GraphBuilder()
+        builder.add_agent_node("end1", MagicMock(spec=Agent), "End 1")
+        builder.add_agent_node("end2", MagicMock(spec=Agent), "End 2")
+        builder.set_exit_nodes(["end1", "end2"])
+        assert len(builder.graph.exit_nodes) == 2
+
+    def test_build_validates_graph(self):
+        """Test that build() validates the graph."""
+        builder = GraphBuilder()
+        builder.add_agent_node("start", MagicMock(spec=Agent), "Start")
+        builder.set_entry_node("start")
+        # Missing exit node - should fail validation
+        with pytest.raises(ValueError, match="validation failed"):
+            builder.build()
+
+    def test_build_returns_valid_graph(self):
+        """Test that build() returns a valid graph."""
+        builder = GraphBuilder()
+        mock_agent = MagicMock(spec=Agent)
+        builder.add_agent_node("start", mock_agent, "Start")
+        builder.add_agent_node("end", mock_agent, "End")
+        builder.connect_nodes("start", "end")
+        builder.set_entry_node("start")
+        builder.set_exit_nodes(["end"])
+
+        graph = builder.build()
+        assert isinstance(graph, ResearchGraph)
+        assert graph.entry_node == "start"
+        assert "end" in graph.exit_nodes
+
+
+class TestFactoryFunctions:
+    """Tests for factory functions."""
+
+    def test_create_iterative_graph(self):
+        """Test creating an iterative research graph."""
+        mock_kg_agent = MagicMock(spec=Agent)
+        mock_ts_agent = MagicMock(spec=Agent)
+        mock_thinking_agent = MagicMock(spec=Agent)
+        mock_writer_agent = MagicMock(spec=Agent)
+
+        graph = create_iterative_graph(
+            knowledge_gap_agent=mock_kg_agent,
+            tool_selector_agent=mock_ts_agent,
+            thinking_agent=mock_thinking_agent,
+            writer_agent=mock_writer_agent,
+        )
+
+        assert isinstance(graph, ResearchGraph)
+        assert graph.entry_node == "thinking"
+        assert "writer" in graph.exit_nodes
+        assert "thinking" in graph.nodes
+        assert "knowledge_gap" in graph.nodes
+        assert "continue_decision" in graph.nodes
+        assert "tool_selector" in graph.nodes
+        assert "writer" in graph.nodes
+
+    def test_create_deep_graph(self):
+        """Test creating a deep research graph."""
+        mock_planner_agent = MagicMock(spec=Agent)
+        mock_kg_agent = MagicMock(spec=Agent)
+        mock_ts_agent = MagicMock(spec=Agent)
+        mock_thinking_agent = MagicMock(spec=Agent)
+        mock_writer_agent = MagicMock(spec=Agent)
+        mock_long_writer_agent = MagicMock(spec=Agent)
+
+        graph = create_deep_graph(
+            planner_agent=mock_planner_agent,
+            knowledge_gap_agent=mock_kg_agent,
+            tool_selector_agent=mock_ts_agent,
+            thinking_agent=mock_thinking_agent,
+            writer_agent=mock_writer_agent,
+            long_writer_agent=mock_long_writer_agent,
+        )
+
+        assert isinstance(graph, ResearchGraph)
+        assert graph.entry_node == "planner"
+        assert "synthesizer" in graph.exit_nodes
+        assert "planner" in graph.nodes
+        assert "parallel_loops_placeholder" in graph.nodes
+        assert "synthesizer" in graph.nodes
tests/unit/agents/test_input_parser.py
ADDED
@@ -0,0 +1,325 @@
+"""Unit tests for InputParserAgent."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from pydantic_ai import AgentRunResult
+
+from src.agents.input_parser import InputParserAgent, create_input_parser_agent
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import ParsedQuery
+
+
+@pytest.fixture
+def mock_model() -> MagicMock:
+    """Create a mock Pydantic AI model."""
+    model = MagicMock()
+    model.name = "test-model"
+    return model
+
+
+@pytest.fixture
+def mock_parsed_query_iterative() -> ParsedQuery:
+    """Create a mock ParsedQuery for iterative mode."""
+    return ParsedQuery(
+        original_query="What is the mechanism of metformin?",
+        improved_query="What is the molecular mechanism of action of metformin in diabetes treatment?",
+        research_mode="iterative",
+        key_entities=["metformin", "diabetes"],
+        research_questions=["What is metformin's mechanism of action?"],
+    )
+
+
+@pytest.fixture
+def mock_parsed_query_deep() -> ParsedQuery:
+    """Create a mock ParsedQuery for deep mode."""
+    return ParsedQuery(
+        original_query="Write a comprehensive report on diabetes treatment",
+        improved_query="Provide a comprehensive analysis of diabetes treatment options, including mechanisms, clinical evidence, and market analysis",
+        research_mode="deep",
+        key_entities=["diabetes", "treatment"],
+        research_questions=[
+            "What are the main treatment options for diabetes?",
+            "What is the clinical evidence for each treatment?",
+            "What is the market size for diabetes treatments?",
+        ],
+    )
+
+
+@pytest.fixture
+def mock_agent_result_iterative(
+    mock_parsed_query_iterative: ParsedQuery,
+) -> AgentRunResult[ParsedQuery]:
+    """Create a mock agent result for iterative mode."""
+    result = MagicMock(spec=AgentRunResult)
+    result.output = mock_parsed_query_iterative
+    return result
+
+
+@pytest.fixture
+def mock_agent_result_deep(
+    mock_parsed_query_deep: ParsedQuery,
+) -> AgentRunResult[ParsedQuery]:
+    """Create a mock agent result for deep mode."""
+    result = MagicMock(spec=AgentRunResult)
+    result.output = mock_parsed_query_deep
+    return result
+
+
+@pytest.fixture
+def input_parser_agent(mock_model: MagicMock) -> InputParserAgent:
+    """Create an InputParserAgent instance with mocked model."""
+    return InputParserAgent(model=mock_model)
+
+
+class TestInputParserAgentInit:
+    """Test InputParserAgent initialization."""
+
+    def test_input_parser_agent_init_with_model(self, mock_model: MagicMock) -> None:
+        """Test InputParserAgent initialization with provided model."""
+        agent = InputParserAgent(model=mock_model)
+        assert agent.model == mock_model
+        assert agent.agent is not None
+
+    @patch("src.agents.input_parser.get_model")
+    def test_input_parser_agent_init_without_model(
+        self, mock_get_model: MagicMock, mock_model: MagicMock
+    ) -> None:
+        """Test InputParserAgent initialization without model (uses default)."""
+        mock_get_model.return_value = mock_model
+        agent = InputParserAgent()
+        assert agent.model == mock_model
+        mock_get_model.assert_called_once()
+
+    def test_input_parser_agent_has_correct_system_prompt(
+        self, input_parser_agent: InputParserAgent
+    ) -> None:
+        """Test that InputParserAgent has correct system prompt."""
+        # System prompt should contain key instructions
+        # In pydantic_ai, system_prompt is a property that returns the prompt string
+        # For mocked agents, we check that the agent was created with a system prompt
+        assert input_parser_agent.agent is not None
+        # The actual system prompt is set during agent creation
+        # We verify the agent exists and was properly initialized
+        # Note: Direct access to system_prompt may not work with mocks
+        # This test verifies the agent structure is correct
+
+
+class TestParse:
+    """Test parse() method."""
+
+    @pytest.mark.asyncio
+    async def test_parse_iterative_query(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_iterative: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test parsing a simple query that should return iterative mode."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
+
+        query = "What is the mechanism of metformin?"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert result.research_mode == "iterative"
+        assert result.original_query == query
+        assert "metformin" in result.key_entities
+        assert input_parser_agent.agent.run.called
+
+    @pytest.mark.asyncio
+    async def test_parse_deep_query(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_deep: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test parsing a complex query that should return deep mode."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_deep)
+
+        query = "Write a comprehensive report on diabetes treatment"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert result.research_mode == "deep"
+        assert result.original_query == query
+        assert len(result.research_questions) > 0
+        assert input_parser_agent.agent.run.called
+
+    @pytest.mark.asyncio
+    async def test_parse_improves_query(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_iterative: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test that parse() improves the query."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
+
+        query = "metformin mechanism"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert result.improved_query != result.original_query
+        assert len(result.improved_query) >= len(result.original_query)
+
+    @pytest.mark.asyncio
+    async def test_parse_extracts_entities(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_iterative: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test that parse() extracts key entities."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
+
+        query = "What is the mechanism of metformin?"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert len(result.key_entities) > 0
+        assert "metformin" in result.key_entities
+
+    @pytest.mark.asyncio
+    async def test_parse_extracts_research_questions(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_deep: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test that parse() extracts research questions."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_deep)
+
+        query = "Write a comprehensive report on diabetes treatment"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert len(result.research_questions) > 0
+
+    @pytest.mark.asyncio
+    async def test_parse_handles_missing_improved_query(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_model: MagicMock,
+    ) -> None:
+        """Test that parse() handles missing improved_query gracefully."""
+        # Create a result with missing improved_query
+        mock_result = MagicMock(spec=AgentRunResult)
+        mock_parsed = ParsedQuery(
+            original_query="test query",
+            improved_query="",  # Empty improved query
+            research_mode="iterative",
+            key_entities=[],
+            research_questions=[],
+        )
+        mock_result.output = mock_parsed
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_result)
+
+        query = "test query"
+        result = await input_parser_agent.parse(query)
+
+        # Should use original_query as fallback
+        assert isinstance(result, ParsedQuery)
+        assert result.improved_query == result.original_query
+
+    @pytest.mark.asyncio
+    async def test_parse_fallback_to_heuristic_on_error(
+        self, input_parser_agent: InputParserAgent
+    ) -> None:
+        """Test that parse() falls back to heuristic when agent fails."""
+        # Make agent.run raise an exception
+        input_parser_agent.agent.run = AsyncMock(side_effect=Exception("Agent failed"))
+
+        # Query with "comprehensive" should trigger deep mode heuristic
+        query = "Write a comprehensive report on diabetes"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert result.research_mode == "deep"  # Heuristic should detect "comprehensive"
+        assert result.original_query == query
+        assert result.improved_query == query  # No improvement on fallback
+
+    @pytest.mark.asyncio
+    async def test_parse_heuristic_iterative_mode(
+        self, input_parser_agent: InputParserAgent
+    ) -> None:
+        """Test that parse() heuristic correctly identifies iterative mode."""
+        # Make agent.run raise an exception
+        input_parser_agent.agent.run = AsyncMock(side_effect=Exception("Agent failed"))
+
+        # Simple query should trigger iterative mode heuristic
+        query = "What is metformin?"
+        result = await input_parser_agent.parse(query)
+
+        assert isinstance(result, ParsedQuery)
+        assert result.research_mode == "iterative"
+        assert result.original_query == query
+
+
+class TestCreateInputParserAgent:
+    """Test create_input_parser_agent() factory function."""
+
+    @patch("src.agents.input_parser.get_model")
+    def test_create_input_parser_agent_with_model(
+        self, mock_get_model: MagicMock, mock_model: MagicMock
+    ) -> None:
+        """Test factory function with provided model."""
+        agent = create_input_parser_agent(model=mock_model)
+        assert isinstance(agent, InputParserAgent)
+        assert agent.model == mock_model
+        mock_get_model.assert_not_called()
+
+    @patch("src.agents.input_parser.get_model")
+    def test_create_input_parser_agent_without_model(
+        self, mock_get_model: MagicMock, mock_model: MagicMock
+    ) -> None:
+        """Test factory function without model (uses default)."""
+        mock_get_model.return_value = mock_model
+        agent = create_input_parser_agent()
+        assert isinstance(agent, InputParserAgent)
+        assert agent.model == mock_model
+        mock_get_model.assert_called_once()
+
+    @patch("src.agents.input_parser.get_model")
+    def test_create_input_parser_agent_handles_error(self, mock_get_model: MagicMock) -> None:
+        """Test factory function handles errors gracefully."""
+        mock_get_model.side_effect = Exception("Model creation failed")
+        with pytest.raises(ConfigurationError, match="Failed to create input parser agent"):
+            create_input_parser_agent()
+
+
+class TestResearchModeDetection:
+    """Test research mode detection logic."""
+
+    @pytest.mark.asyncio
+    async def test_detects_iterative_mode_for_simple_queries(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_iterative: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test that simple queries are detected as iterative."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
+
+        simple_queries = [
+            "What is the mechanism of metformin?",
+            "Find clinical trials for drug X",
+            "What is the capital of France?",
+        ]
+
+        for query in simple_queries:
+            result = await input_parser_agent.parse(query)
+            assert result.research_mode == "iterative", f"Query '{query}' should be iterative"
+
+    @pytest.mark.asyncio
+    async def test_detects_deep_mode_for_complex_queries(
+        self,
+        input_parser_agent: InputParserAgent,
+        mock_agent_result_deep: AgentRunResult[ParsedQuery],
+    ) -> None:
+        """Test that complex queries are detected as deep."""
+        input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_deep)
+
+        complex_queries = [
+            "Write a comprehensive report on diabetes treatment",
+            "Analyze the market for quantum computing",
+            "Provide a detailed analysis of AI trends",
+        ]
+
+        for query in complex_queries:
+            result = await input_parser_agent.parse(query)
+            assert result.research_mode == "deep", f"Query '{query}' should be deep"
tests/unit/agents/test_long_writer.py
ADDED
@@ -0,0 +1,509 @@
| 1 |
+
"""Unit tests for LongWriterAgent."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from pydantic_ai import AgentResult
|
| 7 |
+
|
| 8 |
+
from src.agents.long_writer import LongWriterAgent, LongWriterOutput, create_long_writer_agent
|
| 9 |
+
from src.utils.models import ReportDraft, ReportDraftSection
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture
|
| 13 |
+
def mock_model() -> MagicMock:
|
| 14 |
+
"""Create a mock Pydantic AI model."""
|
| 15 |
+
model = MagicMock()
|
| 16 |
+
model.name = "test-model"
|
| 17 |
+
return model
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.fixture
|
| 21 |
+
def mock_long_writer_output() -> LongWriterOutput:
|
| 22 |
+
"""Create a mock LongWriterOutput."""
|
| 23 |
+
return LongWriterOutput(
|
| 24 |
+
next_section_markdown="## Test Section\n\nContent with citation [1].",
|
| 25 |
+
references=["[1] https://example.com"],
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
|
| 30 |
+
def mock_agent_result(mock_long_writer_output: LongWriterOutput) -> AgentResult[LongWriterOutput]:
|
| 31 |
+
"""Create a mock agent result."""
|
| 32 |
+
result = MagicMock(spec=AgentResult)
|
| 33 |
+
result.output = mock_long_writer_output
|
| 34 |
+
return result
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@pytest.fixture
|
| 38 |
+
def long_writer_agent(mock_model: MagicMock) -> LongWriterAgent:
|
| 39 |
+
"""Create a LongWriterAgent instance with mocked model."""
|
| 40 |
+
return LongWriterAgent(model=mock_model)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@pytest.fixture
|
| 44 |
+
def sample_report_draft() -> ReportDraft:
|
| 45 |
+
"""Create a sample ReportDraft for testing."""
|
| 46 |
+
return ReportDraft(
|
| 47 |
+
sections=[
|
| 48 |
+
ReportDraftSection(
|
| 49 |
+
section_title="Introduction",
|
| 50 |
+
section_content="Introduction content with [1].",
|
| 51 |
+
),
|
| 52 |
+
ReportDraftSection(
|
| 53 |
+
section_title="Methods",
|
| 54 |
+
section_content="Methods content with [2].",
|
| 55 |
+
),
|
| 56 |
+
]
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
+class TestLongWriterAgentInit:
+    """Test LongWriterAgent initialization."""
+
+    def test_long_writer_agent_init_with_model(self, mock_model: MagicMock) -> None:
+        """Test LongWriterAgent initialization with provided model."""
+        agent = LongWriterAgent(model=mock_model)
+        assert agent.model == mock_model
+        assert agent.agent is not None
+
+    @patch("src.agents.long_writer.get_model")
+    def test_long_writer_agent_init_without_model(
+        self, mock_get_model: MagicMock, mock_model: MagicMock
+    ) -> None:
+        """Test LongWriterAgent initialization without model (uses default)."""
+        mock_get_model.return_value = mock_model
+        agent = LongWriterAgent()
+        assert agent.model == mock_model
+        mock_get_model.assert_called_once()
+
+    def test_long_writer_agent_has_structured_output(
+        self, long_writer_agent: LongWriterAgent
+    ) -> None:
+        """Test that LongWriterAgent uses structured output."""
+        assert long_writer_agent.agent.output_type == LongWriterOutput
+
+
+class TestWriteNextSection:
+    """Test write_next_section() method."""
+
+    @pytest.mark.asyncio
+    async def test_write_next_section_basic(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+    ) -> None:
+        """Test basic section writing."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        original_query = "Test query"
+        report_draft = "## Existing Section\n\nContent"
+        next_section_title = "New Section"
+        next_section_draft = "Draft content"
+
+        result = await long_writer_agent.write_next_section(
+            original_query=original_query,
+            report_draft=report_draft,
+            next_section_title=next_section_title,
+            next_section_draft=next_section_draft,
+        )
+
+        assert isinstance(result, LongWriterOutput)
+        assert result.next_section_markdown is not None
+        assert isinstance(result.references, list)
+        assert long_writer_agent.agent.run.called
+
+    @pytest.mark.asyncio
+    async def test_write_next_section_first_section(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+    ) -> None:
+        """Test writing the first section (no existing draft)."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        original_query = "Test query"
+        report_draft = ""  # No existing draft
+        next_section_title = "First Section"
+        next_section_draft = "Draft content"
+
+        result = await long_writer_agent.write_next_section(
+            original_query=original_query,
+            report_draft=report_draft,
+            next_section_title=next_section_title,
+            next_section_draft=next_section_draft,
+        )
+
+        assert isinstance(result, LongWriterOutput)
+        # Check that "No draft yet" was included in prompt
+        call_args = long_writer_agent.agent.run.call_args[0][0]
+        assert "No draft yet" in call_args or report_draft in call_args
+
+    @pytest.mark.asyncio
+    async def test_write_next_section_with_existing_draft(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+    ) -> None:
+        """Test writing section with existing draft."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        original_query = "Test query"
+        report_draft = "## Previous Section\n\nPrevious content"
+        next_section_title = "Next Section"
+        next_section_draft = "Next draft"
+
+        result = await long_writer_agent.write_next_section(
+            original_query=original_query,
+            report_draft=report_draft,
+            next_section_title=next_section_title,
+            next_section_draft=next_section_draft,
+        )
+
+        assert isinstance(result, LongWriterOutput)
+        # Check that existing draft was included in prompt
+        call_args = long_writer_agent.agent.run.call_args[0][0]
+        assert "Previous Section" in call_args
+
+    @pytest.mark.asyncio
+    async def test_write_next_section_returns_references(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+    ) -> None:
+        """Test that write_next_section returns references."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        result = await long_writer_agent.write_next_section(
+            original_query="Test",
+            report_draft="",
+            next_section_title="Test",
+            next_section_draft="Test",
+        )
+
+        assert isinstance(result.references, list)
+        assert len(result.references) > 0
+
+    @pytest.mark.asyncio
+    async def test_write_next_section_handles_empty_draft(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+    ) -> None:
+        """Test writing section with empty draft."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        result = await long_writer_agent.write_next_section(
+            original_query="Test",
+            report_draft="",
+            next_section_title="Test",
+            next_section_draft="",
+        )
+
+        assert isinstance(result, LongWriterOutput)
+
+    @pytest.mark.asyncio
+    async def test_write_next_section_llm_failure(self, long_writer_agent: LongWriterAgent) -> None:
+        """Test write_next_section handles LLM failures gracefully."""
+        long_writer_agent.agent.run = AsyncMock(side_effect=Exception("LLM error"))
+
+        result = await long_writer_agent.write_next_section(
+            original_query="Test",
+            report_draft="",
+            next_section_title="Test",
+            next_section_draft="Test",
+        )
+
+        # Should return fallback section
+        assert isinstance(result, LongWriterOutput)
+        assert "Test" in result.next_section_markdown
+        assert result.references == []
+
+
+class TestWriteReport:
+    """Test write_report() method."""
+
+    @pytest.mark.asyncio
+    async def test_write_report_complete_flow(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+        sample_report_draft: ReportDraft,
+    ) -> None:
+        """Test complete report writing flow."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
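+        # write_report is expected to call the section writer once per draft
+        # section and stitch the results under one title and table of contents.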
+
+        original_query = "Test query"
+        report_title = "Test Report"
+
+        result = await long_writer_agent.write_report(
+            original_query=original_query,
+            report_title=report_title,
+            report_draft=sample_report_draft,
+        )
+
+        assert isinstance(result, str)
+        assert report_title in result
+        assert "Table of Contents" in result
+        assert "Introduction" in result
+        assert "Methods" in result
+        # Should have called write_next_section for each section
+        assert long_writer_agent.agent.run.call_count == len(sample_report_draft.sections)
+
+    @pytest.mark.asyncio
+    async def test_write_report_single_section(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+    ) -> None:
+        """Test writing report with single section."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        report_draft = ReportDraft(
+            sections=[
+                ReportDraftSection(
+                    section_title="Single Section",
+                    section_content="Content",
+                )
+            ]
+        )
+
+        result = await long_writer_agent.write_report(
+            original_query="Test",
+            report_title="Test Report",
+            report_draft=report_draft,
+        )
+
+        assert isinstance(result, str)
+        assert "Single Section" in result
+        assert long_writer_agent.agent.run.call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_write_report_multiple_sections(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+        sample_report_draft: ReportDraft,
+    ) -> None:
+        """Test writing report with multiple sections."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        result = await long_writer_agent.write_report(
+            original_query="Test",
+            report_title="Test Report",
+            report_draft=sample_report_draft,
+        )
+
+        assert isinstance(result, str)
+        assert sample_report_draft.sections[0].section_title in result
+        assert sample_report_draft.sections[1].section_title in result
+        assert long_writer_agent.agent.run.call_count == len(sample_report_draft.sections)
+
+    @pytest.mark.asyncio
+    async def test_write_report_creates_table_of_contents(
+        self,
+        long_writer_agent: LongWriterAgent,
+        mock_agent_result: AgentResult[LongWriterOutput],
+        sample_report_draft: ReportDraft,
+    ) -> None:
+        """Test that write_report creates table of contents."""
+        long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
+
+        result = await long_writer_agent.write_report(
+            original_query="Test",
+            report_title="Test Report",
+            report_draft=sample_report_draft,
+        )
+
+        assert "Table of Contents" in result
+        assert "1. Introduction" in result
+        assert "2. Methods" in result
+
+    @pytest.mark.asyncio
+    async def test_write_report_aggregates_references(
+        self,
+        long_writer_agent: LongWriterAgent,
+        sample_report_draft: ReportDraft,
+    ) -> None:
+        """Test that write_report aggregates references from all sections."""
+        # Create different outputs for each section
+        output1 = LongWriterOutput(
+            next_section_markdown="## Introduction\n\nContent [1].",
+            references=["[1] https://example.com/1"],
+        )
+        output2 = LongWriterOutput(
+            next_section_markdown="## Methods\n\nContent [1].",
+            references=["[1] https://example.com/2"],
+        )
+
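+        # Direct construction below assumes AgentResult can be built with just
+        # an `output` kwarg; adjust if the result type requires more arguments.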
+        results = [AgentResult(output=output1), AgentResult(output=output2)]
+        long_writer_agent.agent.run = AsyncMock(side_effect=results)
+
+        result = await long_writer_agent.write_report(
+            original_query="Test",
+            report_title="Test Report",
+            report_draft=sample_report_draft,
+        )
+
+        assert "References:" in result
+        # Should have both references (reformatted)
+        assert "example.com/1" in result or "[1]" in result
+        assert "example.com/2" in result or "[2]" in result
+
+
+
class TestReformatReferences:
|
| 354 |
+
"""Test _reformat_references() method."""
|
| 355 |
+
|
| 356 |
+
def test_reformat_references_deduplicates(self, long_writer_agent: LongWriterAgent) -> None:
|
| 357 |
+
"""Test that reference reformatting deduplicates URLs."""
|
| 358 |
+
section_markdown = "Content [1] and [2]."
|
| 359 |
+
section_references = [
|
| 360 |
+
"[1] https://example.com",
|
| 361 |
+
"[2] https://example.com", # Duplicate URL
|
| 362 |
+
]
|
| 363 |
+
all_references = []
|
| 364 |
+
|
| 365 |
+
updated_markdown, updated_refs = long_writer_agent._reformat_references(
|
| 366 |
+
section_markdown, section_references, all_references
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Should only have one reference
|
| 370 |
+
assert len(updated_refs) == 1
|
| 371 |
+
assert "example.com" in updated_refs[0]
|
| 372 |
+
|
| 373 |
+
def test_reformat_references_renumbers(self, long_writer_agent: LongWriterAgent) -> None:
|
| 374 |
+
"""Test that reference reformatting renumbers correctly."""
|
| 375 |
+
section_markdown = "Content [1] and [2]."
|
| 376 |
+
section_references = [
|
| 377 |
+
"[1] https://example.com/1",
|
| 378 |
+
"[2] https://example.com/2",
|
| 379 |
+
]
|
| 380 |
+
all_references = ["[1] https://example.com/0"] # Existing reference
|
| 381 |
+
|
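+        # Section-local [1]/[2] should be renumbered against the global list,
+        # landing after the pre-existing reference as [2]/[3].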
+        updated_markdown, updated_refs = long_writer_agent._reformat_references(
+            section_markdown, section_references, all_references
+        )
+
+        # Should have 3 references total (0, 1, 2)
+        assert len(updated_refs) == 3
+        # Markdown should have updated reference numbers
+        assert "[2]" in updated_markdown or "[3]" in updated_markdown
+
+    def test_reformat_references_handles_malformed(
+        self, long_writer_agent: LongWriterAgent
+    ) -> None:
+        """Test that reference reformatting handles malformed references."""
+        section_markdown = "Content [1]."
+        section_references = [
+            "[1] https://example.com",
+            "invalid reference",  # Malformed
+        ]
+        all_references = []
+
+        updated_markdown, updated_refs = long_writer_agent._reformat_references(
+            section_markdown, section_references, all_references
+        )
+
+        # Should still work, just skip invalid references
+        assert isinstance(updated_markdown, str)
+        assert isinstance(updated_refs, list)
+
+    def test_reformat_references_empty_list(self, long_writer_agent: LongWriterAgent) -> None:
+        """Test reference reformatting with empty reference list."""
+        section_markdown = "Content without citations."
+        section_references = []
+        all_references = []
+
+        updated_markdown, updated_refs = long_writer_agent._reformat_references(
+            section_markdown, section_references, all_references
+        )
+
+        assert updated_markdown == section_markdown
+        assert updated_refs == []
+
+    def test_reformat_references_preserves_markdown(
+        self, long_writer_agent: LongWriterAgent
+    ) -> None:
+        """Test that reference reformatting preserves markdown content."""
+        section_markdown = "## Section\n\nContent [1] with **bold** text."
+        section_references = ["[1] https://example.com"]
+        all_references = []
+
+        updated_markdown, _ = long_writer_agent._reformat_references(
+            section_markdown, section_references, all_references
+        )
+
+        assert "## Section" in updated_markdown
+        assert "**bold**" in updated_markdown
+
+
+class TestReformatSectionHeadings:
+    """Test _reformat_section_headings() method."""
+
+    def test_reformat_section_headings_level_2(self, long_writer_agent: LongWriterAgent) -> None:
+        """Test that headings are reformatted to level 2."""
+        section_markdown = "## Section Title\n\nContent"
+        result = long_writer_agent._reformat_section_headings(section_markdown)
+        assert "## Section Title" in result
+
+    def test_reformat_section_headings_level_3(self, long_writer_agent: LongWriterAgent) -> None:
+        """Test that level 3 headings are adjusted correctly."""
+        section_markdown = "### Section Title\n\nContent"
+        result = long_writer_agent._reformat_section_headings(section_markdown)
+        # Should be adjusted to level 2
+        assert "## Section Title" in result
+
+    def test_reformat_section_headings_no_headings(
+        self, long_writer_agent: LongWriterAgent
+    ) -> None:
+        """Test reformatting with no headings."""
+        section_markdown = "Just content without headings."
+        result = long_writer_agent._reformat_section_headings(section_markdown)
+        assert result == section_markdown
+
+    def test_reformat_section_headings_preserves_content(
+        self, long_writer_agent: LongWriterAgent
+    ) -> None:
+        """Test that content is preserved during heading reformatting."""
+        section_markdown = "# Section\n\nImportant content here."
+        result = long_writer_agent._reformat_section_headings(section_markdown)
+        assert "Important content here" in result
+
+
+class TestCreateLongWriterAgent:
+    """Test create_long_writer_agent factory function."""
+
+    @patch("src.agents.long_writer.get_model")
+    @patch("src.agents.long_writer.LongWriterAgent")
+    def test_create_long_writer_agent_success(
+        self,
+        mock_long_writer_agent_class: MagicMock,
+        mock_get_model: MagicMock,
+        mock_model: MagicMock,
+    ) -> None:
+        """Test successful long writer agent creation."""
+        mock_get_model.return_value = mock_model
+        mock_agent_instance = MagicMock()
+        mock_long_writer_agent_class.return_value = mock_agent_instance
+
+        result = create_long_writer_agent()
+
+        assert result == mock_agent_instance
+        mock_long_writer_agent_class.assert_called_once_with(model=mock_model)
+
+    @patch("src.agents.long_writer.get_model")
+    @patch("src.agents.long_writer.LongWriterAgent")
+    def test_create_long_writer_agent_with_custom_model(
+        self,
+        mock_long_writer_agent_class: MagicMock,
+        mock_get_model: MagicMock,
+        mock_model: MagicMock,
+    ) -> None:
+        """Test long writer agent creation with custom model."""
+        mock_agent_instance = MagicMock()
+        mock_long_writer_agent_class.return_value = mock_agent_instance
+
+        result = create_long_writer_agent(model=mock_model)
+
+        assert result == mock_agent_instance
+        mock_long_writer_agent_class.assert_called_once_with(model=mock_model)
+        mock_get_model.assert_not_called()