Joseph Pollack committed
Commit 731a241 · unverified · 1 Parent(s): 1515e72

adds the initial iterative and deep research workflows

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.

Files changed (50)
  1. .gitignore +3 -0
  2. .pre-commit-config.yaml +1 -0
  3. AGENTS.md +0 -118
  4. CLAUDE.md +0 -111
  5. GEMINI.md +0 -98
  6. docs/CONFIGURATION.md +291 -0
  7. docs/architecture/graph_orchestration.md +141 -0
  8. docs/examples/writer_agents_usage.md +415 -0
  9. docs/implementation/02_phase_search.md +31 -19
  10. pyproject.toml +11 -0
  11. src/agent_factory/agents.py +339 -0
  12. src/agent_factory/graph_builder.py +608 -0
  13. src/agent_factory/judges.py +9 -0
  14. src/agents/input_parser.py +178 -0
  15. src/agents/judge_agent.py +1 -1
  16. src/agents/knowledge_gap.py +156 -0
  17. src/agents/long_writer.py +431 -0
  18. src/agents/proofreader.py +205 -0
  19. src/agents/search_agent.py +1 -1
  20. src/agents/state.py +27 -5
  21. src/agents/thinking.py +148 -0
  22. src/agents/tool_selector.py +168 -0
  23. src/agents/writer.py +209 -0
  24. src/{orchestrator.py → legacy_orchestrator.py} +0 -0
  25. src/middleware/__init__.py +33 -0
  26. src/middleware/budget_tracker.py +390 -0
  27. src/middleware/state_machine.py +129 -0
  28. src/middleware/workflow_manager.py +322 -0
  29. src/orchestrator/__init__.py +48 -0
  30. src/orchestrator/graph_orchestrator.py +953 -0
  31. src/orchestrator/planner_agent.py +174 -0
  32. src/orchestrator/research_flow.py +999 -0
  33. src/orchestrator_factory.py +1 -1
  34. src/tools/__init__.py +8 -1
  35. src/tools/crawl_adapter.py +58 -0
  36. src/tools/rag_tool.py +183 -0
  37. src/tools/search_handler.py +67 -5
  38. src/tools/tool_executor.py +193 -0
  39. src/tools/web_search_adapter.py +63 -0
  40. src/utils/citation_validator.py +91 -0
  41. src/utils/config.py +98 -0
  42. src/utils/models.py +267 -1
  43. tests/integration/test_deep_research.py +352 -0
  44. tests/integration/test_middleware_integration.py +245 -0
  45. tests/integration/test_parallel_loops_judge.py +396 -0
  46. tests/integration/test_rag_integration.py +343 -0
  47. tests/integration/test_research_flows.py +584 -0
  48. tests/unit/agent_factory/test_graph_builder.py +439 -0
  49. tests/unit/agents/test_input_parser.py +325 -0
  50. tests/unit/agents/test_long_writer.py +509 -0
.gitignore CHANGED
@@ -1,3 +1,6 @@
+ folder/
+ .cursor/
+ .ruff_cache/
  # Python
  __pycache__/
  *.py[cod]
.pre-commit-config.yaml CHANGED
@@ -13,6 +13,7 @@ repos:
      hooks:
        - id: mypy
          files: ^src/
+         exclude: ^folder
          additional_dependencies:
            - pydantic>=2.7
            - pydantic-settings>=2.2
AGENTS.md DELETED
@@ -1,118 +0,0 @@
- # AGENTS.md
-
- This file provides guidance to AI agents when working with code in this repository.
-
- ## Project Overview
-
- DeepCritical is an AI-native drug repurposing research agent for a HuggingFace hackathon. It uses a search-and-judge loop to autonomously search biomedical databases (PubMed, ClinicalTrials.gov, bioRxiv) and synthesize evidence for queries like "What existing drugs might help treat long COVID fatigue?".
-
- **Current Status:** Phases 1-13 COMPLETE (Foundation through Modal sandbox integration).
-
- ## Development Commands
-
- ```bash
- # Install all dependencies (including dev)
- make install   # or: uv sync --all-extras && uv run pre-commit install
-
- # Run all quality checks (lint + typecheck + test) - MUST PASS BEFORE COMMIT
- make check
-
- # Individual commands
- make test      # uv run pytest tests/unit/ -v
- make lint      # uv run ruff check src tests
- make format    # uv run ruff format src tests
- make typecheck # uv run mypy src
- make test-cov  # uv run pytest --cov=src --cov-report=term-missing
-
- # Run single test
- uv run pytest tests/unit/utils/test_config.py::TestSettings::test_default_max_iterations -v
-
- # Integration tests (real APIs)
- uv run pytest -m integration
- ```
-
- ## Architecture
-
- **Pattern**: Search-and-judge loop with multi-tool orchestration.
-
- ```text
- User Question → Orchestrator
-
- Search Loop:
-   1. Query PubMed, ClinicalTrials.gov, bioRxiv
-   2. Gather evidence
-   3. Judge quality ("Do we have enough?")
-   4. If NO → Refine query, search more
-   5. If YES → Synthesize findings (+ optional Modal analysis)
-
- Research Report with Citations
- ```
-
- **Key Components**:
-
- - `src/orchestrator.py` - Main agent loop
- - `src/tools/pubmed.py` - PubMed E-utilities search
- - `src/tools/clinicaltrials.py` - ClinicalTrials.gov API
- - `src/tools/biorxiv.py` - bioRxiv/medRxiv preprint search
- - `src/tools/code_execution.py` - Modal sandbox execution
- - `src/tools/search_handler.py` - Scatter-gather orchestration
- - `src/services/embeddings.py` - Semantic search & deduplication (ChromaDB)
- - `src/services/statistical_analyzer.py` - Statistical analysis via Modal
- - `src/agent_factory/judges.py` - LLM-based evidence assessment
- - `src/agents/` - Magentic multi-agent mode (SearchAgent, JudgeAgent, etc.)
- - `src/mcp_tools.py` - MCP tool wrappers for Claude Desktop
- - `src/utils/config.py` - Pydantic Settings (loads from `.env`)
- - `src/utils/models.py` - Evidence, Citation, SearchResult models
- - `src/utils/exceptions.py` - Exception hierarchy
- - `src/app.py` - Gradio UI with MCP server (HuggingFace Spaces)
-
- **Break Conditions**: Judge approval, token budget (50K max), or max iterations (default 10).
-
- ## Configuration
-
- Settings via pydantic-settings from `.env`:
-
- - `LLM_PROVIDER`: "openai" or "anthropic"
- - `OPENAI_API_KEY` / `ANTHROPIC_API_KEY`: LLM keys
- - `NCBI_API_KEY`: Optional, for higher PubMed rate limits
- - `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: For Modal sandbox (optional)
- - `MAX_ITERATIONS`: 1-50, default 10
- - `LOG_LEVEL`: DEBUG, INFO, WARNING, ERROR
-
- ## Exception Hierarchy
-
- ```text
- DeepCriticalError (base)
- ├── SearchError
- │   └── RateLimitError
- ├── JudgeError
- └── ConfigurationError
- ```
-
- ## Testing
-
- - **TDD**: Write tests first in `tests/unit/`, implement in `src/`
- - **Markers**: `unit`, `integration`, `slow`
- - **Mocking**: `respx` for httpx, `pytest-mock` for general mocking
- - **Fixtures**: `tests/conftest.py` has `mock_httpx_client`, `mock_llm_response`
-
- ## Coding Standards
-
- - Python 3.11+, strict mypy, ruff (100-char lines)
- - Type all functions, use Pydantic models for data
- - Use `structlog` for logging, not print
- - Conventional commits: `feat(scope):`, `fix:`, `docs:`
-
- ## Git Workflow
-
- - `main`: Production-ready (GitHub)
- - `dev`: Development integration (GitHub)
- - Remote `origin`: GitHub (source of truth for PRs/code review)
- - Remote `huggingface-upstream`: HuggingFace Spaces (deployment target)
-
- **HuggingFace Spaces Collaboration:**
-
- - Each contributor should use their own dev branch: `yourname-dev` (e.g., `vcms-dev`, `mario-dev`)
- - **DO NOT push directly to `main` or `dev` on HuggingFace** - these can be overwritten easily
- - GitHub is the source of truth; HuggingFace is for deployment/demo
- - Consider using git hooks to prevent accidental pushes to protected branches
CLAUDE.md DELETED
@@ -1,111 +0,0 @@
- # CLAUDE.md
-
- This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
-
- ## Project Overview
-
- DeepCritical is an AI-native drug repurposing research agent for a HuggingFace hackathon. It uses a search-and-judge loop to autonomously search biomedical databases (PubMed, ClinicalTrials.gov, bioRxiv) and synthesize evidence for queries like "What existing drugs might help treat long COVID fatigue?".
-
- **Current Status:** Phases 1-13 COMPLETE (Foundation through Modal sandbox integration).
-
- ## Development Commands
-
- ```bash
- # Install all dependencies (including dev)
- make install   # or: uv sync --all-extras && uv run pre-commit install
-
- # Run all quality checks (lint + typecheck + test) - MUST PASS BEFORE COMMIT
- make check
-
- # Individual commands
- make test      # uv run pytest tests/unit/ -v
- make lint      # uv run ruff check src tests
- make format    # uv run ruff format src tests
- make typecheck # uv run mypy src
- make test-cov  # uv run pytest --cov=src --cov-report=term-missing
-
- # Run single test
- uv run pytest tests/unit/utils/test_config.py::TestSettings::test_default_max_iterations -v
-
- # Integration tests (real APIs)
- uv run pytest -m integration
- ```
-
- ## Architecture
-
- **Pattern**: Search-and-judge loop with multi-tool orchestration.
-
- ```text
- User Question → Orchestrator
-
- Search Loop:
-   1. Query PubMed, ClinicalTrials.gov, bioRxiv
-   2. Gather evidence
-   3. Judge quality ("Do we have enough?")
-   4. If NO → Refine query, search more
-   5. If YES → Synthesize findings (+ optional Modal analysis)
-
- Research Report with Citations
- ```
-
- **Key Components**:
-
- - `src/orchestrator.py` - Main agent loop
- - `src/tools/pubmed.py` - PubMed E-utilities search
- - `src/tools/clinicaltrials.py` - ClinicalTrials.gov API
- - `src/tools/biorxiv.py` - bioRxiv/medRxiv preprint search
- - `src/tools/code_execution.py` - Modal sandbox execution
- - `src/tools/search_handler.py` - Scatter-gather orchestration
- - `src/services/embeddings.py` - Semantic search & deduplication (ChromaDB)
- - `src/services/statistical_analyzer.py` - Statistical analysis via Modal
- - `src/agent_factory/judges.py` - LLM-based evidence assessment
- - `src/agents/` - Magentic multi-agent mode (SearchAgent, JudgeAgent, etc.)
- - `src/mcp_tools.py` - MCP tool wrappers for Claude Desktop
- - `src/utils/config.py` - Pydantic Settings (loads from `.env`)
- - `src/utils/models.py` - Evidence, Citation, SearchResult models
- - `src/utils/exceptions.py` - Exception hierarchy
- - `src/app.py` - Gradio UI with MCP server (HuggingFace Spaces)
-
- **Break Conditions**: Judge approval, token budget (50K max), or max iterations (default 10).
-
- ## Configuration
-
- Settings via pydantic-settings from `.env`:
-
- - `LLM_PROVIDER`: "openai" or "anthropic"
- - `OPENAI_API_KEY` / `ANTHROPIC_API_KEY`: LLM keys
- - `NCBI_API_KEY`: Optional, for higher PubMed rate limits
- - `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: For Modal sandbox (optional)
- - `MAX_ITERATIONS`: 1-50, default 10
- - `LOG_LEVEL`: DEBUG, INFO, WARNING, ERROR
-
- ## Exception Hierarchy
-
- ```text
- DeepCriticalError (base)
- ├── SearchError
- │   └── RateLimitError
- ├── JudgeError
- └── ConfigurationError
- ```
-
- ## Testing
-
- - **TDD**: Write tests first in `tests/unit/`, implement in `src/`
- - **Markers**: `unit`, `integration`, `slow`
- - **Mocking**: `respx` for httpx, `pytest-mock` for general mocking
- - **Fixtures**: `tests/conftest.py` has `mock_httpx_client`, `mock_llm_response`
-
- ## Git Workflow
-
- - `main`: Production-ready (GitHub)
- - `dev`: Development integration (GitHub)
- - Remote `origin`: GitHub (source of truth for PRs/code review)
- - Remote `huggingface-upstream`: HuggingFace Spaces (deployment target)
-
- **HuggingFace Spaces Collaboration:**
-
- - Each contributor should use their own dev branch: `yourname-dev` (e.g., `vcms-dev`, `mario-dev`)
- - **DO NOT push directly to `main` or `dev` on HuggingFace** - these can be overwritten easily
- - GitHub is the source of truth; HuggingFace is for deployment/demo
- - Consider using git hooks to prevent accidental pushes to protected branches
GEMINI.md DELETED
@@ -1,98 +0,0 @@
- # DeepCritical Context
-
- ## Project Overview
-
- **DeepCritical** is an AI-native Medical Drug Repurposing Research Agent.
- **Goal:** To accelerate the discovery of new uses for existing drugs by intelligently searching biomedical literature (PubMed, ClinicalTrials.gov, bioRxiv), evaluating evidence, and hypothesizing potential applications.
-
- **Architecture:**
- The project follows a **Vertical Slice Architecture** (Search -> Judge -> Orchestrator) and adheres to **Strict TDD** (Test-Driven Development).
-
- **Current Status:**
-
- - **Phases 1-9:** COMPLETE. Foundation, Search, Judge, UI, Orchestrator, Embeddings, Hypothesis, Report, Cleanup.
- - **Phases 10-11:** COMPLETE. ClinicalTrials.gov and bioRxiv integration.
- - **Phase 12:** COMPLETE. MCP Server integration (Gradio MCP at `/gradio_api/mcp/`).
- - **Phase 13:** COMPLETE. Modal sandbox for statistical analysis.
-
- ## Tech Stack & Tooling
-
- - **Language:** Python 3.11 (Pinned)
- - **Package Manager:** `uv` (Rust-based, extremely fast)
- - **Frameworks:** `pydantic`, `pydantic-ai`, `httpx`, `gradio[mcp]`
- - **Vector DB:** `chromadb` with `sentence-transformers` for semantic search
- - **Code Execution:** `modal` for secure sandboxed Python execution
- - **Testing:** `pytest`, `pytest-asyncio`, `respx` (for mocking)
- - **Quality:** `ruff` (linting/formatting), `mypy` (strict type checking), `pre-commit`
-
- ## Building & Running
-
- | Command | Description |
- | :--- | :--- |
- | `make install` | Install dependencies and pre-commit hooks. |
- | `make test` | Run unit tests. |
- | `make lint` | Run Ruff linter. |
- | `make format` | Run Ruff formatter. |
- | `make typecheck` | Run Mypy static type checker. |
- | `make check` | **The Golden Gate:** Runs lint, typecheck, and test. Must pass before committing. |
- | `make clean` | Clean up cache and artifacts. |
-
- ## Directory Structure
-
- - `src/`: Source code
-   - `utils/`: Shared utilities (`config.py`, `exceptions.py`, `models.py`)
-   - `tools/`: Search tools (`pubmed.py`, `clinicaltrials.py`, `biorxiv.py`, `code_execution.py`)
-   - `services/`: Services (`embeddings.py`, `statistical_analyzer.py`)
-   - `agents/`: Magentic multi-agent mode agents
-   - `agent_factory/`: Agent definitions (judges, prompts)
-   - `mcp_tools.py`: MCP tool wrappers for Claude Desktop integration
-   - `app.py`: Gradio UI with MCP server
- - `tests/`: Test suite
-   - `unit/`: Isolated unit tests (Mocked)
-   - `integration/`: Real API tests (Marked as slow/integration)
- - `docs/`: Documentation and Implementation Specs
- - `examples/`: Working demos for each phase
-
- ## Key Components
-
- - `src/orchestrator.py` - Main agent loop
- - `src/tools/pubmed.py` - PubMed E-utilities search
- - `src/tools/clinicaltrials.py` - ClinicalTrials.gov API
- - `src/tools/biorxiv.py` - bioRxiv/medRxiv preprint search
- - `src/tools/code_execution.py` - Modal sandbox execution
- - `src/services/statistical_analyzer.py` - Statistical analysis via Modal
- - `src/mcp_tools.py` - MCP tool wrappers
- - `src/app.py` - Gradio UI (HuggingFace Spaces) with MCP server
-
- ## Configuration
-
- Settings via pydantic-settings from `.env`:
-
- - `LLM_PROVIDER`: "openai" or "anthropic"
- - `OPENAI_API_KEY` / `ANTHROPIC_API_KEY`: LLM keys
- - `NCBI_API_KEY`: Optional, for higher PubMed rate limits
- - `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: For Modal sandbox (optional)
- - `MAX_ITERATIONS`: 1-50, default 10
- - `LOG_LEVEL`: DEBUG, INFO, WARNING, ERROR
-
- ## Development Conventions
-
- 1. **Strict TDD:** Write failing tests in `tests/unit/` *before* implementing logic in `src/`.
- 2. **Type Safety:** All code must pass `mypy --strict`. Use Pydantic models for data exchange.
- 3. **Linting:** Zero tolerance for Ruff errors.
- 4. **Mocking:** Use `respx` or `unittest.mock` for all external API calls in unit tests.
- 5. **Vertical Slices:** Implement features end-to-end rather than layer-by-layer.
-
- ## Git Workflow
-
- - `main`: Production-ready (GitHub)
- - `dev`: Development integration (GitHub)
- - Remote `origin`: GitHub (source of truth for PRs/code review)
- - Remote `huggingface-upstream`: HuggingFace Spaces (deployment target)
-
- **HuggingFace Spaces Collaboration:**
-
- - Each contributor should use their own dev branch: `yourname-dev` (e.g., `vcms-dev`, `mario-dev`)
- - **DO NOT push directly to `main` or `dev` on HuggingFace** - these can be overwritten easily
- - GitHub is the source of truth; HuggingFace is for deployment/demo
- - Consider using git hooks to prevent accidental pushes to protected branches
docs/CONFIGURATION.md ADDED
@@ -0,0 +1,291 @@
+ # Configuration Guide
+
+ ## Overview
+
+ DeepCritical uses **Pydantic Settings** for centralized configuration management. All settings are defined in `src/utils/config.py` and can be configured via environment variables or a `.env` file.
+
+ ## Quick Start
+
+ 1. Copy the example environment file (if available) or create a `.env` file in the project root
+ 2. Set at least one LLM API key (`OPENAI_API_KEY` or `ANTHROPIC_API_KEY`)
+ 3. Optionally configure other services as needed
+
+ ## Configuration System
+
+ ### How It Works
+
+ - **Settings Class**: `Settings` class in `src/utils/config.py` extends `BaseSettings` from `pydantic_settings`
+ - **Environment File**: Automatically loads from `.env` file (if present)
+ - **Environment Variables**: Reads from environment variables (case-insensitive)
+ - **Type Safety**: Strongly-typed fields with validation
+ - **Singleton Pattern**: Global `settings` instance for easy access
+
+ ### Usage
+
+ ```python
+ from src.utils.config import settings
+
+ # Check if API keys are available
+ if settings.has_openai_key:
+     # Use OpenAI
+     pass
+
+ # Access configuration values
+ max_iterations = settings.max_iterations
+ web_search_provider = settings.web_search_provider
+ ```
+
+ ## Required Configuration
+
+ ### At Least One LLM Provider
+
+ You must configure at least one LLM provider:
+
+ **OpenAI:**
+ ```bash
+ LLM_PROVIDER=openai
+ OPENAI_API_KEY=your_openai_api_key_here
+ OPENAI_MODEL=gpt-5.1
+ ```
+
+ **Anthropic:**
+ ```bash
+ LLM_PROVIDER=anthropic
+ ANTHROPIC_API_KEY=your_anthropic_api_key_here
+ ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
+ ```
+
+ ## Optional Configuration
+
+ ### Embedding Configuration
+
+ ```bash
+ # Embedding Provider: "openai", "local", or "huggingface"
+ EMBEDDING_PROVIDER=local
+
+ # OpenAI Embedding Model (used by LlamaIndex RAG)
+ OPENAI_EMBEDDING_MODEL=text-embedding-3-small
+
+ # Local Embedding Model (sentence-transformers)
+ LOCAL_EMBEDDING_MODEL=all-MiniLM-L6-v2
+
+ # HuggingFace Embedding Model
+ HUGGINGFACE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
+ ```
+
+ ### HuggingFace Configuration
+
+ ```bash
+ # HuggingFace API Token (for inference API)
+ HUGGINGFACE_API_KEY=your_huggingface_api_key_here
+ # Or use HF_TOKEN (alternative name)
+
+ # Default HuggingFace Model ID
+ HUGGINGFACE_MODEL=meta-llama/Llama-3.1-8B-Instruct
+ ```
+
+ ### Web Search Configuration
+
+ ```bash
+ # Web Search Provider: "serper", "searchxng", "brave", "tavily", or "duckduckgo"
+ # Default: "duckduckgo" (no API key required)
+ WEB_SEARCH_PROVIDER=duckduckgo
+
+ # Serper API Key (for Google search via Serper)
+ SERPER_API_KEY=your_serper_api_key_here
+
+ # SearchXNG Host URL
+ SEARCHXNG_HOST=http://localhost:8080
+
+ # Brave Search API Key
+ BRAVE_API_KEY=your_brave_api_key_here
+
+ # Tavily API Key
+ TAVILY_API_KEY=your_tavily_api_key_here
+ ```
+
+ ### PubMed Configuration
+
+ ```bash
+ # NCBI API Key (optional, for higher rate limits: 10 req/sec vs 3 req/sec)
+ NCBI_API_KEY=your_ncbi_api_key_here
+ ```
+
+ ### Agent Configuration
+
+ ```bash
+ # Maximum iterations per research loop
+ MAX_ITERATIONS=10
+
+ # Search timeout in seconds
+ SEARCH_TIMEOUT=30
+
+ # Use graph-based execution for research flows
+ USE_GRAPH_EXECUTION=false
+ ```
+
+ ### Budget & Rate Limiting Configuration
+
+ ```bash
+ # Default token budget per research loop
+ DEFAULT_TOKEN_LIMIT=100000
+
+ # Default time limit per research loop (minutes)
+ DEFAULT_TIME_LIMIT_MINUTES=10
+
+ # Default iterations limit per research loop
+ DEFAULT_ITERATIONS_LIMIT=10
+ ```
+
+ ### RAG Service Configuration
+
+ ```bash
+ # ChromaDB collection name for RAG
+ RAG_COLLECTION_NAME=deepcritical_evidence
+
+ # Number of top results to retrieve from RAG
+ RAG_SIMILARITY_TOP_K=5
+
+ # Automatically ingest evidence into RAG
+ RAG_AUTO_INGEST=true
+ ```
+
+ ### ChromaDB Configuration
+
+ ```bash
+ # ChromaDB storage path
+ CHROMA_DB_PATH=./chroma_db
+
+ # Whether to persist ChromaDB to disk
+ CHROMA_DB_PERSIST=true
+
+ # ChromaDB server host (for remote ChromaDB, optional)
+ # CHROMA_DB_HOST=localhost
+
+ # ChromaDB server port (for remote ChromaDB, optional)
+ # CHROMA_DB_PORT=8000
+ ```
+
+ ### External Services
+
+ ```bash
+ # Modal Token ID (for Modal sandbox execution)
+ MODAL_TOKEN_ID=your_modal_token_id_here
+
+ # Modal Token Secret
+ MODAL_TOKEN_SECRET=your_modal_token_secret_here
+ ```
+
+ ### Logging Configuration
+
+ ```bash
+ # Log Level: "DEBUG", "INFO", "WARNING", or "ERROR"
+ LOG_LEVEL=INFO
+ ```
+
+ ## Configuration Properties
+
+ The `Settings` class provides helpful properties for checking configuration:
+
+ ```python
+ from src.utils.config import settings
+
+ # Check API key availability
+ settings.has_openai_key       # bool
+ settings.has_anthropic_key    # bool
+ settings.has_huggingface_key  # bool
+ settings.has_any_llm_key      # bool
+
+ # Check service availability
+ settings.modal_available       # bool
+ settings.web_search_available  # bool
+ ```
+
+ ## Environment Variables Reference
+
+ ### Required (at least one LLM)
+ - `OPENAI_API_KEY` or `ANTHROPIC_API_KEY` - At least one LLM provider key
+
+ ### Optional LLM Providers
+ - `DEEPSEEK_API_KEY` (Phase 2)
+ - `OPENROUTER_API_KEY` (Phase 2)
+ - `GEMINI_API_KEY` (Phase 2)
+ - `PERPLEXITY_API_KEY` (Phase 2)
+ - `HUGGINGFACE_API_KEY` or `HF_TOKEN`
+ - `AZURE_OPENAI_ENDPOINT` (Phase 2)
+ - `AZURE_OPENAI_DEPLOYMENT` (Phase 2)
+ - `AZURE_OPENAI_API_KEY` (Phase 2)
+ - `AZURE_OPENAI_API_VERSION` (Phase 2)
+ - `LOCAL_MODEL_URL` (Phase 2)
+
+ ### Web Search
+ - `WEB_SEARCH_PROVIDER` (default: "duckduckgo")
+ - `SERPER_API_KEY`
+ - `SEARCHXNG_HOST`
+ - `BRAVE_API_KEY`
+ - `TAVILY_API_KEY`
+
+ ### Embeddings
+ - `EMBEDDING_PROVIDER` (default: "local")
+ - `HUGGINGFACE_EMBEDDING_MODEL` (optional)
+
+ ### RAG
+ - `RAG_COLLECTION_NAME` (default: "deepcritical_evidence")
+ - `RAG_SIMILARITY_TOP_K` (default: 5)
+ - `RAG_AUTO_INGEST` (default: true)
+
+ ### ChromaDB
+ - `CHROMA_DB_PATH` (default: "./chroma_db")
+ - `CHROMA_DB_PERSIST` (default: true)
+ - `CHROMA_DB_HOST` (optional)
+ - `CHROMA_DB_PORT` (optional)
+
+ ### Budget
+ - `DEFAULT_TOKEN_LIMIT` (default: 100000)
+ - `DEFAULT_TIME_LIMIT_MINUTES` (default: 10)
+ - `DEFAULT_ITERATIONS_LIMIT` (default: 10)
+
+ ### Other
+ - `LLM_PROVIDER` (default: "openai")
+ - `NCBI_API_KEY` (optional)
+ - `MODAL_TOKEN_ID` (optional)
+ - `MODAL_TOKEN_SECRET` (optional)
+ - `MAX_ITERATIONS` (default: 10)
+ - `LOG_LEVEL` (default: "INFO")
+ - `USE_GRAPH_EXECUTION` (default: false)
+
+ ## Validation
+
+ Settings are validated on load using Pydantic validation:
+
+ - **Type checking**: All fields are strongly typed
+ - **Range validation**: Numeric fields have min/max constraints
+ - **Literal validation**: Enum fields only accept specific values
+ - **Required fields**: API keys are checked when accessed via `get_api_key()`
+
+ ## Error Handling
+
+ Configuration errors raise `ConfigurationError`:
+
+ ```python
+ from src.utils.config import settings
+ from src.utils.exceptions import ConfigurationError
+
+ try:
+     api_key = settings.get_api_key()
+ except ConfigurationError as e:
+     print(f"Configuration error: {e}")
+ ```
+
+ ## Future Enhancements (Phase 2)
+
+ The following configurations are planned for Phase 2:
+
+ 1. **Additional LLM Providers**: DeepSeek, OpenRouter, Gemini, Perplexity, Azure OpenAI, Local models
+ 2. **Model Selection**: Reasoning/main/fast model configuration
+ 3. **Service Integration**: Migrate `folder/llm_config.py` to centralized config
+
+ See `CONFIGURATION_ANALYSIS.md` for the complete implementation plan.
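To make the validation rules described in the guide concrete, here is a minimal sketch of how such a `Settings` class is typically declared with pydantic-settings. It is illustrative rather than the `src/utils/config.py` added by this commit; field names mirror the environment variables documented above, while the exact defaults and property names are assumptions.

```python
# Illustrative sketch only - the real class lives in src/utils/config.py.
from typing import Literal

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    # Loads from a .env file and case-insensitive environment variables.
    model_config = SettingsConfigDict(env_file=".env", case_sensitive=False)

    # Literal validation: only these providers are accepted.
    llm_provider: Literal["openai", "anthropic"] = "openai"
    openai_api_key: str | None = None
    anthropic_api_key: str | None = None

    # Range validation: 1-50 with a default of 10, as documented above.
    max_iterations: int = Field(default=10, ge=1, le=50)
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"

    @property
    def has_openai_key(self) -> bool:
        return self.openai_api_key is not None


settings = Settings()  # module-level singleton, as the guide describes
```

With this shape, an out-of-range `MAX_ITERATIONS=99` or an unknown `LLM_PROVIDER` fails at load time with a Pydantic `ValidationError` rather than surfacing later inside a research loop.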
docs/architecture/graph_orchestration.md ADDED
@@ -0,0 +1,141 @@
+ # Graph Orchestration Architecture
+
+ ## Overview
+
+ Phase 4 implements a graph-based orchestration system for research workflows using Pydantic AI agents as nodes. This enables better parallel execution, conditional routing, and state management compared to simple agent chains.
+
+ ## Graph Structure
+
+ ### Nodes
+
+ Graph nodes represent different stages in the research workflow:
+
+ 1. **Agent Nodes**: Execute Pydantic AI agents
+    - Input: Prompt/query
+    - Output: Structured or unstructured response
+    - Examples: `KnowledgeGapAgent`, `ToolSelectorAgent`, `ThinkingAgent`
+
+ 2. **State Nodes**: Update or read workflow state
+    - Input: Current state
+    - Output: Updated state
+    - Examples: Update evidence, update conversation history
+
+ 3. **Decision Nodes**: Make routing decisions based on conditions
+    - Input: Current state/results
+    - Output: Next node ID
+    - Examples: Continue research vs. complete research
+
+ 4. **Parallel Nodes**: Execute multiple nodes concurrently
+    - Input: List of node IDs
+    - Output: Aggregated results
+    - Examples: Parallel iterative research loops
+
+ ### Edges
+
+ Edges define transitions between nodes:
+
+ 1. **Sequential Edges**: Always traversed (no condition)
+    - From: Source node
+    - To: Target node
+    - Condition: None (always True)
+
+ 2. **Conditional Edges**: Traversed based on condition
+    - From: Source node
+    - To: Target node
+    - Condition: Callable that returns bool
+    - Example: If research complete → go to writer, else → continue loop
+
+ 3. **Parallel Edges**: Used for parallel execution branches
+    - From: Parallel node
+    - To: Multiple target nodes
+    - Execution: All targets run concurrently
+
+ ## Graph Patterns
+
+ ### Iterative Research Graph
+
+ ```
+ [Input] → [Thinking] → [Knowledge Gap] → [Decision: Complete?]
+                                            ↓ No          ↓ Yes
+                                     [Tool Selector]    [Writer]
+                                            ↓
+                                    [Execute Tools] → [Loop Back]
+ ```
+
+ ### Deep Research Graph
+
+ ```
+ [Input] → [Planner] → [Parallel Iterative Loops] → [Synthesizer]
+                           ↓         ↓         ↓
+                        [Loop1]   [Loop2]   [Loop3]
+ ```
+
+ ## State Management
+
+ State is managed via `WorkflowState` using `ContextVar` for thread-safe isolation:
+
+ - **Evidence**: Collected evidence from searches
+ - **Conversation**: Iteration history (gaps, tool calls, findings, thoughts)
+ - **Embedding Service**: For semantic search
+
+ State transitions occur at state nodes, which update the global workflow state.
+
+ ## Execution Flow
+
+ 1. **Graph Construction**: Build graph from nodes and edges
+ 2. **Graph Validation**: Ensure graph is valid (no cycles, all nodes reachable)
+ 3. **Graph Execution**: Traverse graph from entry node
+ 4. **Node Execution**: Execute each node based on type
+ 5. **Edge Evaluation**: Determine next node(s) based on edges
+ 6. **Parallel Execution**: Use `asyncio.gather()` for parallel nodes
+ 7. **State Updates**: Update state at state nodes
+ 8. **Event Streaming**: Yield events during execution for UI
+
+ ## Conditional Routing
+
+ Decision nodes evaluate conditions and return next node IDs:
+
+ - **Knowledge Gap Decision**: If `research_complete` → writer, else → tool selector
+ - **Budget Decision**: If budget exceeded → exit, else → continue
+ - **Iteration Decision**: If max iterations → exit, else → continue
+
+ ## Parallel Execution
+
+ Parallel nodes execute multiple nodes concurrently:
+
+ - Each parallel branch runs independently
+ - Results are aggregated after all branches complete
+ - State is synchronized after parallel execution
+ - Errors in one branch don't stop other branches
+
+ ## Budget Enforcement
+
+ Budget constraints are enforced at decision nodes:
+
+ - **Token Budget**: Track LLM token usage
+ - **Time Budget**: Track elapsed time
+ - **Iteration Budget**: Track iteration count
+
+ If any budget is exceeded, execution routes to the exit node.
+
+ ## Error Handling
+
+ Errors are handled at multiple levels:
+
+ 1. **Node Level**: Catch errors in individual node execution
+ 2. **Graph Level**: Handle errors during graph traversal
+ 3. **State Level**: Rollback state changes on error
+
+ Errors are logged and yield error events for the UI.
+
+ ## Backward Compatibility
+
+ Graph execution is optional via feature flag:
+
+ - `USE_GRAPH_EXECUTION=true`: Use graph-based execution
+ - `USE_GRAPH_EXECUTION=false`: Use agent chain execution (existing)
+
+ This allows gradual migration and fallback if needed.
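The node/edge and parallel-execution ideas in this document can be sketched in a few lines of Python. The names below (`Node`, `Edge`, `run_parallel`) are hypothetical; the commit's real implementation lives in `src/agent_factory/graph_builder.py` and `src/orchestrator/graph_orchestrator.py`, but the sketch shows the two mechanisms the document relies on: a conditional edge as a predicate over state, and `asyncio.gather()` with error isolation for parallel nodes.

```python
# Minimal sketch, not the committed implementation.
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


@dataclass
class Node:
    node_id: str
    run: Callable[[dict], Awaitable[dict]]  # agent/state/decision behavior


@dataclass
class Edge:
    source: str
    target: str
    # Conditional edge: traversed only when the predicate returns True.
    # A sequential edge is simply the default always-True condition.
    condition: Callable[[dict], bool] = lambda state: True


def next_targets(edges: list[Edge], source: str, state: dict) -> list[str]:
    """Edge evaluation: pick the targets whose conditions hold."""
    return [e.target for e in edges if e.source == source and e.condition(state)]


async def run_parallel(nodes: list[Node], state: dict) -> list[dict]:
    """Parallel node: run branches concurrently and aggregate results.

    return_exceptions=True keeps one failing branch from cancelling the
    others, matching the error-isolation behavior described above.
    """
    results = await asyncio.gather(
        *(n.run(dict(state)) for n in nodes), return_exceptions=True
    )
    return [r for r in results if not isinstance(r, BaseException)]
```

A "Decision: Complete?" node then reduces to an edge pair whose conditions are `lambda s: s["research_complete"]` and its negation.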
docs/examples/writer_agents_usage.md ADDED
@@ -0,0 +1,415 @@
+ # Writer Agents Usage Examples
+
+ This document provides examples of how to use the writer agents in DeepCritical for generating research reports.
+
+ ## Overview
+
+ DeepCritical provides three writer agents for different report generation scenarios:
+
+ 1. **WriterAgent** - Basic writer for simple reports from findings
+ 2. **LongWriterAgent** - Iterative writer for long-form multi-section reports
+ 3. **ProofreaderAgent** - Finalizes and polishes report drafts
+
+ ## WriterAgent
+
+ The `WriterAgent` generates final reports from research findings. It's used in iterative research flows.
+
+ ### Basic Usage
+
+ ```python
+ from src.agent_factory.agents import create_writer_agent
+
+ # Create writer agent
+ writer = create_writer_agent()
+
+ # Generate report
+ query = "What is the capital of France?"
+ findings = """
+ Paris is the capital of France [1].
+ It is located in the north-central part of the country [2].
+
+ [1] https://example.com/france-info
+ [2] https://example.com/paris-info
+ """
+
+ report = await writer.write_report(
+     query=query,
+     findings=findings,
+ )
+
+ print(report)
+ ```
+
+ ### With Output Length Specification
+
+ ```python
+ report = await writer.write_report(
+     query="Explain machine learning",
+     findings=findings,
+     output_length="500 words",
+ )
+ ```
+
+ ### With Additional Instructions
+
+ ```python
+ report = await writer.write_report(
+     query="Explain machine learning",
+     findings=findings,
+     output_length="A comprehensive overview",
+     output_instructions="Use formal academic language and include examples",
+ )
+ ```
+
+ ### Integration with IterativeResearchFlow
+
+ The `WriterAgent` is automatically used by `IterativeResearchFlow`:
+
+ ```python
+ from src.agent_factory.agents import create_iterative_flow
+
+ flow = create_iterative_flow(max_iterations=5, max_time_minutes=10)
+ report = await flow.run(
+     query="What is quantum computing?",
+     output_length="A detailed explanation",
+     output_instructions="Include practical applications",
+ )
+ ```
+
+ ## LongWriterAgent
+
+ The `LongWriterAgent` iteratively writes report sections with proper citation management. It's used in deep research flows.
+
+ ### Basic Usage
+
+ ```python
+ from src.agent_factory.agents import create_long_writer_agent
+ from src.utils.models import ReportDraft, ReportDraftSection
+
+ # Create long writer agent
+ long_writer = create_long_writer_agent()
+
+ # Create report draft with sections
+ report_draft = ReportDraft(
+     sections=[
+         ReportDraftSection(
+             section_title="Introduction",
+             section_content="Draft content for introduction with [1].",
+         ),
+         ReportDraftSection(
+             section_title="Methods",
+             section_content="Draft content for methods with [2].",
+         ),
+         ReportDraftSection(
+             section_title="Results",
+             section_content="Draft content for results with [3].",
+         ),
+     ]
+ )
+
+ # Generate full report
+ report = await long_writer.write_report(
+     original_query="What are the main features of Python?",
+     report_title="Python Programming Language Overview",
+     report_draft=report_draft,
+ )
+
+ print(report)
+ ```
+
+ ### Writing Individual Sections
+
+ You can also write sections one at a time:
+
+ ```python
+ # Write first section
+ section_output = await long_writer.write_next_section(
+     original_query="What is Python?",
+     report_draft="",  # No existing draft
+     next_section_title="Introduction",
+     next_section_draft="Python is a programming language...",
+ )
+
+ print(section_output.next_section_markdown)
+ print(section_output.references)
+
+ # Write second section with existing draft
+ section_output = await long_writer.write_next_section(
+     original_query="What is Python?",
+     report_draft="# Report\n\n## Introduction\n\nContent...",
+     next_section_title="Features",
+     next_section_draft="Python features include...",
+ )
+ ```
+
+ ### Integration with DeepResearchFlow
+
+ The `LongWriterAgent` is automatically used by `DeepResearchFlow`:
+
+ ```python
+ from src.agent_factory.agents import create_deep_flow
+
+ flow = create_deep_flow(
+     max_iterations=5,
+     max_time_minutes=10,
+     use_long_writer=True,  # Use long writer (default)
+ )
+
+ report = await flow.run("What are the main features of Python programming language?")
+ ```
+
+ ## ProofreaderAgent
+
+ The `ProofreaderAgent` finalizes and polishes report drafts by removing duplicates, adding summaries, and refining wording.
+
+ ### Basic Usage
+
+ ```python
+ from src.agent_factory.agents import create_proofreader_agent
+ from src.utils.models import ReportDraft, ReportDraftSection
+
+ # Create proofreader agent
+ proofreader = create_proofreader_agent()
+
+ # Create report draft
+ report_draft = ReportDraft(
+     sections=[
+         ReportDraftSection(
+             section_title="Introduction",
+             section_content="Python is a programming language [1].",
+         ),
+         ReportDraftSection(
+             section_title="Features",
+             section_content="Python has many features [2].",
+         ),
+     ]
+ )
+
+ # Proofread and finalize
+ final_report = await proofreader.proofread(
+     query="What is Python?",
+     report_draft=report_draft,
+ )
+
+ print(final_report)
+ ```
+
+ ### Integration with DeepResearchFlow
+
+ Use `ProofreaderAgent` instead of `LongWriterAgent`:
+
+ ```python
+ from src.agent_factory.agents import create_deep_flow
+
+ flow = create_deep_flow(
+     max_iterations=5,
+     max_time_minutes=10,
+     use_long_writer=False,  # Use proofreader instead
+ )
+
+ report = await flow.run("What are the main features of Python?")
+ ```
+
+ ## Error Handling
+
+ All writer agents include robust error handling:
+
+ ### Handling Empty Inputs
+
+ ```python
+ # WriterAgent handles empty findings gracefully
+ report = await writer.write_report(
+     query="Test query",
+     findings="",  # Empty findings
+ )
+ # Returns a fallback report
+
+ # LongWriterAgent handles empty sections
+ report = await long_writer.write_report(
+     original_query="Test",
+     report_title="Test Report",
+     report_draft=ReportDraft(sections=[]),  # Empty draft
+ )
+ # Returns minimal report
+
+ # ProofreaderAgent handles empty drafts
+ report = await proofreader.proofread(
+     query="Test",
+     report_draft=ReportDraft(sections=[]),
+ )
+ # Returns minimal report
+ ```
+
+ ### Retry Logic
+
+ All agents automatically retry on transient errors (timeouts, connection errors):
+
+ ```python
+ # Automatically retries up to 3 times on transient failures
+ report = await writer.write_report(
+     query="Test query",
+     findings=findings,
+ )
+ ```
+
+ ### Fallback Reports
+
+ If all retries fail, agents return fallback reports:
+
+ ```python
+ # Returns fallback report with query and findings
+ report = await writer.write_report(
+     query="Test query",
+     findings=findings,
+ )
+ # Fallback includes: "# Research Report\n\n## Query\n...\n\n## Findings\n..."
+ ```
+
+ ## Citation Validation
+
+ ### For Markdown Reports
+
+ Use the markdown citation validator:
+
+ ```python
+ from src.utils.citation_validator import validate_markdown_citations
+ from src.utils.models import Evidence, Citation
+
+ # Collect evidence during research
+ evidence = [
+     Evidence(
+         content="Paris is the capital of France",
+         citation=Citation(
+             source="web",
+             title="France Information",
+             url="https://example.com/france",
+             date="2024-01-01",
+         ),
+     ),
+ ]
+
+ # Generate report
+ report = await writer.write_report(query="What is the capital of France?", findings=findings)
+
+ # Validate citations
+ validated_report, removed_count = validate_markdown_citations(report, evidence)
+
+ if removed_count > 0:
+     print(f"Removed {removed_count} invalid citations")
+ ```
+
+ ### For ResearchReport Objects
+
+ Use the structured citation validator:
+
+ ```python
+ from src.utils.citation_validator import validate_references
+
+ # For ResearchReport objects (from ReportAgent)
+ validated_report = validate_references(report, evidence)
+ ```
+
+ ## Custom Model Configuration
+
+ All writer agents support custom model configuration:
+
+ ```python
+ from pydantic_ai import Model
+
+ # Create custom model
+ custom_model = Model("openai", "gpt-4")
+
+ # Use with writer agents
+ writer = create_writer_agent(model=custom_model)
+ long_writer = create_long_writer_agent(model=custom_model)
+ proofreader = create_proofreader_agent(model=custom_model)
+ ```
+
+ ## Best Practices
+
+ 1. **Use WriterAgent for simple reports** - When you have findings as a string and need a quick report
+ 2. **Use LongWriterAgent for structured reports** - When you need multiple sections with proper citation management
+ 3. **Use ProofreaderAgent for final polish** - When you have draft sections and need a polished final report
+ 4. **Validate citations** - Always validate citations against collected evidence
+ 5. **Handle errors gracefully** - All agents return fallback reports on failure
+ 6. **Specify output length** - Use `output_length` parameter to control report size
+ 7. **Provide instructions** - Use `output_instructions` for specific formatting requirements
+
+ ## Integration Examples
+
+ ### Full Iterative Research Flow
+
+ ```python
+ from src.agent_factory.agents import create_iterative_flow
+
+ flow = create_iterative_flow(
+     max_iterations=5,
+     max_time_minutes=10,
+ )
+
+ report = await flow.run(
+     query="What is machine learning?",
+     output_length="A comprehensive 1000-word explanation",
+     output_instructions="Include practical examples and use cases",
+ )
+ ```
+
+ ### Full Deep Research Flow with Long Writer
+
+ ```python
+ from src.agent_factory.agents import create_deep_flow
+
+ flow = create_deep_flow(
+     max_iterations=5,
+     max_time_minutes=10,
+     use_long_writer=True,
+ )
+
+ report = await flow.run("What are the main features of Python programming language?")
+ ```
+
+ ### Full Deep Research Flow with Proofreader
+
+ ```python
+ from src.agent_factory.agents import create_deep_flow
+
+ flow = create_deep_flow(
+     max_iterations=5,
+     max_time_minutes=10,
+     use_long_writer=False,  # Use proofreader
+ )
+
+ report = await flow.run("Explain quantum computing basics")
+ ```
+
+ ## Troubleshooting
+
+ ### Empty Reports
+
+ If you get empty reports, check:
+ - Input validation logs (agents log warnings for empty inputs)
+ - LLM API key configuration
+ - Network connectivity
+
+ ### Citation Issues
+
+ If citations are missing or invalid:
+ - Use `validate_markdown_citations()` to check citations
+ - Ensure Evidence objects are properly collected during research
+ - Check that URLs in findings match Evidence URLs
+
+ ### Performance Issues
+
+ For large reports:
+ - Use `LongWriterAgent` for better section management
+ - Consider truncating very long findings (agents do this automatically)
+ - Use appropriate `max_time_minutes` settings
+
+ ## See Also
+
+ - [Research Flows Documentation](../orchestrator/research_flows.md)
+ - [Citation Validation](../utils/citation_validation.md)
+ - [Agent Factory](../agent_factory/agents.md)
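The document describes `validate_markdown_citations()` as checking report citations against collected `Evidence` and removing the invalid ones. A simplified sketch of that idea follows; the function name and regex here are hypothetical stand-ins, not the logic of the committed `src/utils/citation_validator.py`, and they assume a reference-list format like `[1] https://example.com/france-info` as used in the examples above.

```python
# Simplified, hypothetical sketch of citation validation against Evidence.
import re

from src.utils.models import Evidence


def strip_unbacked_citations(report: str, evidence: list[Evidence]) -> tuple[str, int]:
    """Drop reference-list entries whose URL is not backed by collected Evidence."""
    valid_urls = {e.citation.url for e in evidence}
    removed = 0

    def check(match: re.Match[str]) -> str:
        nonlocal removed
        if match.group("url") in valid_urls:
            return match.group(0)
        removed += 1
        return ""  # remove entries that cite URLs we never collected

    # Matches lines like "[1] https://example.com/france-info"
    pattern = re.compile(r"^\[(?P<n>\d+)\]\s+(?P<url>\S+)\s*$", re.MULTILINE)
    return pattern.sub(check, report), removed
```

The real validator presumably also renumbers or removes the matching inline `[n]` markers; the sketch only shows the core URL-membership check.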
docs/implementation/02_phase_search.md CHANGED
@@ -4,6 +4,8 @@
  **Philosophy**: "Real data, mocked connections."
  **Prerequisite**: Phase 1 complete (all tests passing)
 
+ > **⚠️ Implementation Note (2025-01-27)**: The DuckDuckGo WebTool specified in this phase was removed in favor of the Europe PMC tool (see Phase 11). Europe PMC provides better coverage for biomedical research by including preprints, peer-reviewed articles, and patents. The current implementation uses PubMed, ClinicalTrials.gov, and Europe PMC as search sources.
+
  ---
 
  ## 1. The Slice Definition
@@ -12,17 +14,20 @@ This slice covers:
  1. **Input**: A string query (e.g., "metformin Alzheimer's disease").
  2. **Process**:
     - Fetch from PubMed (E-utilities API).
-    - Fetch from Web (DuckDuckGo).
+    - ~~Fetch from Web (DuckDuckGo).~~ **REMOVED** - Replaced by Europe PMC in Phase 11
     - Normalize results into `Evidence` models.
  3. **Output**: A list of `Evidence` objects.
 
  **Files to Create**:
  - `src/utils/models.py` - Pydantic models (Evidence, Citation, SearchResult)
  - `src/tools/pubmed.py` - PubMed E-utilities tool
- - `src/tools/websearch.py` - DuckDuckGo search tool
+ - ~~`src/tools/websearch.py` - DuckDuckGo search tool~~ **REMOVED** - See Phase 11 for Europe PMC replacement
  - `src/tools/search_handler.py` - Orchestrates multiple tools
  - `src/tools/__init__.py` - Exports
 
+ **Additional Files (Post-Phase 2 Enhancements)**:
+ - `src/tools/query_utils.py` - Query preprocessing (removes question words, expands medical synonyms)
+
  ---
 
  ## 2. PubMed E-utilities API Reference
@@ -767,17 +772,23 @@ async def test_pubmed_live_search():
 
  ## 8. Implementation Checklist
 
- - [ ] Create `src/utils/models.py` with all Pydantic models (Evidence, Citation, SearchResult)
- - [ ] Create `src/tools/__init__.py` with SearchTool Protocol and exports
- - [ ] Implement `src/tools/pubmed.py` with PubMedTool class
- - [ ] Implement `src/tools/websearch.py` with WebTool class
- - [ ] Create `src/tools/search_handler.py` with SearchHandler class
- - [ ] Write tests in `tests/unit/tools/test_pubmed.py`
- - [ ] Write tests in `tests/unit/tools/test_websearch.py`
- - [ ] Write tests in `tests/unit/tools/test_search_handler.py`
- - [ ] Run `uv run pytest tests/unit/tools/ -v` — **ALL TESTS MUST PASS**
+ - [x] Create `src/utils/models.py` with all Pydantic models (Evidence, Citation, SearchResult) - **COMPLETE**
+ - [x] Create `src/tools/__init__.py` with SearchTool Protocol and exports - **COMPLETE**
+ - [x] Implement `src/tools/pubmed.py` with PubMedTool class - **COMPLETE**
+ - [ ] ~~Implement `src/tools/websearch.py` with WebTool class~~ - **REMOVED** (replaced by Europe PMC in Phase 11)
+ - [x] Create `src/tools/search_handler.py` with SearchHandler class - **COMPLETE**
+ - [x] Write tests in `tests/unit/tools/test_pubmed.py` - **COMPLETE** (basic tests)
+ - [ ] Write tests in `tests/unit/tools/test_websearch.py` - **N/A** (WebTool removed)
+ - [x] Write tests in `tests/unit/tools/test_search_handler.py` - **COMPLETE** (basic tests)
+ - [x] Run `uv run pytest tests/unit/tools/ -v` — **ALL TESTS MUST PASS** - **PASSING**
  - [ ] (Optional) Run integration test: `uv run pytest -m integration`
- - [ ] Commit: `git commit -m "feat: phase 2 search slice complete"`
+ - [ ] Add edge case tests (rate limiting, error handling, timeouts) - **PENDING**
+ - [ ] Commit: `git commit -m "feat: phase 2 search slice complete"` - **DONE**
+
+ **Post-Phase 2 Enhancements**:
+ - [x] Query preprocessing (`src/tools/query_utils.py`) - **ADDED**
+ - [x] Europe PMC tool (Phase 11) - **ADDED**
+ - [x] ClinicalTrials tool (Phase 10) - **ADDED**
 
  ---
@@ -785,20 +796,19 @@ async def test_pubmed_live_search():
 
  Phase 2 is **COMPLETE** when:
 
- 1. All unit tests pass: `uv run pytest tests/unit/tools/ -v`
- 2. `SearchHandler` can execute with both tools
- 3. Graceful degradation: if PubMed fails, WebTool results still return
- 4. Rate limiting is enforced (verify no 429 errors)
- 5. Can run this in Python REPL:
+ 1. All unit tests pass: `uv run pytest tests/unit/tools/ -v` - **PASSING**
+ 2. `SearchHandler` can execute with search tools - **WORKING**
+ 3. Graceful degradation: if one tool fails, other tools still return results - **IMPLEMENTED**
+ 4. Rate limiting is enforced (verify no 429 errors) - **IMPLEMENTED**
+ 5. Can run this in Python REPL:
 
  ```python
  import asyncio
  from src.tools.pubmed import PubMedTool
- from src.tools.websearch import WebTool
  from src.tools.search_handler import SearchHandler
 
  async def test():
-     handler = SearchHandler([PubMedTool(), WebTool()])
+     handler = SearchHandler([PubMedTool()])
      result = await handler.execute("metformin alzheimer")
      print(f"Found {result.total_found} results")
      for e in result.evidence[:3]:
@@ -807,4 +817,6 @@ async def test():
  asyncio.run(test())
  ```
 
+ **Note**: WebTool was removed in favor of Europe PMC (Phase 11). The current implementation uses PubMed as the primary Phase 2 tool, with Europe PMC and ClinicalTrials added in later phases.
+
  **Proceed to Phase 3 ONLY after all checkboxes are complete.**
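The diff above adds `src/tools/query_utils.py`, described only as "removes question words, expands medical synonyms." A hypothetical sketch of that kind of preprocessing is shown below; the stop-word list and synonym map are illustrative assumptions, not the committed implementation.

```python
# Hypothetical sketch of the query preprocessing query_utils.py is
# described as providing; word lists here are assumptions.
QUESTION_WORDS = {"what", "which", "how", "does", "can", "might", "help"}
SYNONYMS = {"alzheimer": ["alzheimer's disease", "AD"]}


def preprocess_query(query: str) -> str:
    """Strip question words and expand known medical synonyms."""
    terms = [t for t in query.lower().split() if t not in QUESTION_WORDS]
    expanded: list[str] = []
    for term in terms:
        expanded.append(term)
        expanded.extend(SYNONYMS.get(term, []))  # append known synonyms
    return " ".join(expanded)


# e.g. "What drugs might help treat alzheimer" ->
#      "drugs treat alzheimer alzheimer's disease AD"
```

Feeding the expanded string to `SearchHandler.execute()` keeps the natural-language question out of the PubMed term query while broadening recall.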
pyproject.toml CHANGED
@@ -24,6 +24,7 @@ dependencies = [
      "tenacity>=8.2",    # Retry logic
      "structlog>=24.1",  # Structured logging
      "requests>=2.32.5", # ClinicalTrials.gov (httpx blocked by WAF)
+     "pydantic-graph>=1.22.0",
  ]
 
  [project.optional-dependencies]
@@ -91,6 +92,7 @@ ignore = [
      "PLW0603", # Global statement (singleton pattern for Modal)
      "PLC0415", # Lazy imports for optional dependencies
      "E402",    # Module level import not at top (needed for pytest.importorskip)
+     "E501",    # Line too long (ignore line length violations)
      "RUF100",  # Unused noqa (version differences between local/CI)
  ]
@@ -105,9 +107,12 @@ ignore_missing_imports = true
  disallow_untyped_defs = true
  warn_return_any = true
  warn_unused_ignores = false
+ explicit_package_bases = true
+ mypy_path = "."
  exclude = [
      "^reference_repos/",
      "^examples/",
+     "^folder/",
  ]
 
  # ============== PYTEST CONFIG ==============
@@ -137,5 +142,11 @@ exclude_lines = [
      "raise NotImplementedError",
  ]
 
+ [dependency-groups]
+ dev = [
+     "structlog>=25.5.0",
+     "ty>=0.0.1a28",
+ ]
+
  # Note: agent-framework-core is optional for magentic mode (multi-agent orchestration)
  # Version pinned to 1.0.0b* to avoid breaking changes. CI skips tests via pytest.importorskip
src/agent_factory/agents.py CHANGED
@@ -0,0 +1,339 @@
+ """Agent factory functions for creating research agents.
+
+ Provides factory functions for creating all Pydantic AI agents used in
+ the research workflows, following the pattern from judges.py.
+ """
+
+ from typing import TYPE_CHECKING, Any
+
+ import structlog
+
+ from src.utils.config import settings
+ from src.utils.exceptions import ConfigurationError
+
+ if TYPE_CHECKING:
+     from src.agent_factory.graph_builder import GraphBuilder
+     from src.agents.input_parser import InputParserAgent
+     from src.agents.knowledge_gap import KnowledgeGapAgent
+     from src.agents.long_writer import LongWriterAgent
+     from src.agents.proofreader import ProofreaderAgent
+     from src.agents.thinking import ThinkingAgent
+     from src.agents.tool_selector import ToolSelectorAgent
+     from src.agents.writer import WriterAgent
+     from src.orchestrator.graph_orchestrator import GraphOrchestrator
+     from src.orchestrator.planner_agent import PlannerAgent
+     from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
+
+ logger = structlog.get_logger()
+
+
+ def create_input_parser_agent(model: Any | None = None) -> "InputParserAgent":
+     """
+     Create input parser agent for query analysis and research mode detection.
+
+     Args:
+         model: Optional Pydantic AI model. If None, uses settings default.
+
+     Returns:
+         Configured InputParserAgent instance
+
+     Raises:
+         ConfigurationError: If required API keys are missing
+     """
+     from src.agents.input_parser import create_input_parser_agent as _create_agent
+
+     try:
+         logger.debug("Creating input parser agent")
+         return _create_agent(model=model)
+     except Exception as e:
+         logger.error("Failed to create input parser agent", error=str(e))
+         raise ConfigurationError(f"Failed to create input parser agent: {e}") from e
+
+
+ def create_planner_agent(model: Any | None = None) -> "PlannerAgent":
+     """
+     Create planner agent with web search and crawl tools.
+
+     Args:
+         model: Optional Pydantic AI model. If None, uses settings default.
+
+     Returns:
+         Configured PlannerAgent instance
+
+     Raises:
+         ConfigurationError: If required API keys are missing
+     """
+     # Lazy import to avoid circular dependencies
+     from src.orchestrator.planner_agent import create_planner_agent as _create_planner_agent
+
+     try:
+         logger.debug("Creating planner agent")
+         return _create_planner_agent(model=model)
+     except Exception as e:
+         logger.error("Failed to create planner agent", error=str(e))
+         raise ConfigurationError(f"Failed to create planner agent: {e}") from e
+
+
+ def create_knowledge_gap_agent(model: Any | None = None) -> "KnowledgeGapAgent":
+     """
+     Create knowledge gap agent for evaluating research completeness.
+
+     Args:
+         model: Optional Pydantic AI model. If None, uses settings default.
+
+     Returns:
+         Configured KnowledgeGapAgent instance
+
+     Raises:
+         ConfigurationError: If required API keys are missing
+     """
+     from src.agents.knowledge_gap import create_knowledge_gap_agent as _create_agent
+
+     try:
+         logger.debug("Creating knowledge gap agent")
+         return _create_agent(model=model)
+     except Exception as e:
+         logger.error("Failed to create knowledge gap agent", error=str(e))
+         raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e
+
+
+ def create_tool_selector_agent(model: Any | None = None) -> "ToolSelectorAgent":
+     """
+     Create tool selector agent for choosing tools to address gaps.
+
+     Args:
+         model: Optional Pydantic AI model. If None, uses settings default.
+
+     Returns:
+         Configured ToolSelectorAgent instance
+
+     Raises:
+         ConfigurationError: If required API keys are missing
+     """
+     from src.agents.tool_selector import create_tool_selector_agent as _create_agent
+
+     try:
+         logger.debug("Creating tool selector agent")
+         return _create_agent(model=model)
+     except Exception as e:
+         logger.error("Failed to create tool selector agent", error=str(e))
+         raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e
+
+
+ def create_thinking_agent(model: Any | None = None) -> "ThinkingAgent":
+     """
+     Create thinking agent for generating observations.
+
+     Args:
+         model: Optional Pydantic AI model. If None, uses settings default.
+
+     Returns:
+         Configured ThinkingAgent instance
+
+     Raises:
+         ConfigurationError: If required API keys are missing
+     """
+     from src.agents.thinking import create_thinking_agent as _create_agent
+
+     try:
+         logger.debug("Creating thinking agent")
+         return _create_agent(model=model)
+     except Exception as e:
+         logger.error("Failed to create thinking agent", error=str(e))
143
+ raise ConfigurationError(f"Failed to create thinking agent: {e}") from e
144
+
145
+
146
+ def create_writer_agent(model: Any | None = None) -> "WriterAgent":
147
+ """
148
+ Create writer agent for generating final reports.
149
+
150
+ Args:
151
+ model: Optional Pydantic AI model. If None, uses settings default.
152
+
153
+ Returns:
154
+ Configured WriterAgent instance
155
+
156
+ Raises:
157
+ ConfigurationError: If required API keys are missing
158
+ """
159
+ from src.agents.writer import create_writer_agent as _create_agent
160
+
161
+ try:
162
+ logger.debug("Creating writer agent")
163
+ return _create_agent(model=model)
164
+ except Exception as e:
165
+ logger.error("Failed to create writer agent", error=str(e))
166
+ raise ConfigurationError(f"Failed to create writer agent: {e}") from e
167
+
168
+
169
+ def create_long_writer_agent(model: Any | None = None) -> "LongWriterAgent":
170
+ """
171
+ Create long writer agent for iteratively writing report sections.
172
+
173
+ Args:
174
+ model: Optional Pydantic AI model. If None, uses settings default.
175
+
176
+ Returns:
177
+ Configured LongWriterAgent instance
178
+
179
+ Raises:
180
+ ConfigurationError: If required API keys are missing
181
+ """
182
+ from src.agents.long_writer import create_long_writer_agent as _create_agent
183
+
184
+ try:
185
+ logger.debug("Creating long writer agent")
186
+ return _create_agent(model=model)
187
+ except Exception as e:
188
+ logger.error("Failed to create long writer agent", error=str(e))
189
+ raise ConfigurationError(f"Failed to create long writer agent: {e}") from e
190
+
191
+
192
+ def create_proofreader_agent(model: Any | None = None) -> "ProofreaderAgent":
193
+ """
194
+ Create proofreader agent for finalizing report drafts.
195
+
196
+ Args:
197
+ model: Optional Pydantic AI model. If None, uses settings default.
198
+
199
+ Returns:
200
+ Configured ProofreaderAgent instance
201
+
202
+ Raises:
203
+ ConfigurationError: If required API keys are missing
204
+ """
205
+ from src.agents.proofreader import create_proofreader_agent as _create_agent
206
+
207
+ try:
208
+ logger.debug("Creating proofreader agent")
209
+ return _create_agent(model=model)
210
+ except Exception as e:
211
+ logger.error("Failed to create proofreader agent", error=str(e))
212
+ raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e
213
+
214
+
215
+ def create_iterative_flow(
216
+ max_iterations: int = 5,
217
+ max_time_minutes: int = 10,
218
+ verbose: bool = True,
219
+ use_graph: bool | None = None,
220
+ ) -> "IterativeResearchFlow":
221
+ """
222
+ Create iterative research flow.
223
+
224
+ Args:
225
+ max_iterations: Maximum number of iterations
226
+ max_time_minutes: Maximum time in minutes
227
+ verbose: Whether to log progress
228
+ use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution
229
+
230
+ Returns:
231
+ Configured IterativeResearchFlow instance
232
+ """
233
+ from src.orchestrator.research_flow import IterativeResearchFlow
234
+
235
+ try:
236
+ # Use settings default if not explicitly provided
237
+ if use_graph is None:
238
+ use_graph = settings.use_graph_execution
239
+
240
+ logger.debug("Creating iterative research flow", use_graph=use_graph)
241
+ return IterativeResearchFlow(
242
+ max_iterations=max_iterations,
243
+ max_time_minutes=max_time_minutes,
244
+ verbose=verbose,
245
+ use_graph=use_graph,
246
+ )
247
+ except Exception as e:
248
+ logger.error("Failed to create iterative flow", error=str(e))
249
+ raise ConfigurationError(f"Failed to create iterative flow: {e}") from e
250
+
251
+
252
+ def create_deep_flow(
253
+ max_iterations: int = 5,
254
+ max_time_minutes: int = 10,
255
+ verbose: bool = True,
256
+ use_long_writer: bool = True,
257
+ use_graph: bool | None = None,
258
+ ) -> "DeepResearchFlow":
259
+ """
260
+ Create deep research flow.
261
+
262
+ Args:
263
+ max_iterations: Maximum iterations per section
264
+ max_time_minutes: Maximum time per section
265
+ verbose: Whether to log progress
266
+ use_long_writer: Whether to use long writer (True) or proofreader (False)
267
+ use_graph: Whether to use graph execution. If None, reads from settings.use_graph_execution
268
+
269
+ Returns:
270
+ Configured DeepResearchFlow instance
271
+ """
272
+ from src.orchestrator.research_flow import DeepResearchFlow
273
+
274
+ try:
275
+ # Use settings default if not explicitly provided
276
+ if use_graph is None:
277
+ use_graph = settings.use_graph_execution
278
+
279
+ logger.debug("Creating deep research flow", use_graph=use_graph)
280
+ return DeepResearchFlow(
281
+ max_iterations=max_iterations,
282
+ max_time_minutes=max_time_minutes,
283
+ verbose=verbose,
284
+ use_long_writer=use_long_writer,
285
+ use_graph=use_graph,
286
+ )
287
+ except Exception as e:
288
+ logger.error("Failed to create deep flow", error=str(e))
289
+ raise ConfigurationError(f"Failed to create deep flow: {e}") from e
290
+
291
+
292
+ def create_graph_orchestrator(
293
+ mode: str = "auto",
294
+ max_iterations: int = 5,
295
+ max_time_minutes: int = 10,
296
+ use_graph: bool = True,
297
+ ) -> "GraphOrchestrator":
298
+ """
299
+ Create graph orchestrator.
300
+
301
+ Args:
302
+ mode: Research mode ("iterative", "deep", or "auto")
303
+ max_iterations: Maximum iterations per loop
304
+ max_time_minutes: Maximum time per loop
305
+ use_graph: Whether to use graph execution (True) or agent chains (False)
306
+
307
+ Returns:
308
+ Configured GraphOrchestrator instance
309
+ """
310
+ from src.orchestrator.graph_orchestrator import create_graph_orchestrator as _create
311
+
312
+ try:
313
+ logger.debug("Creating graph orchestrator", mode=mode, use_graph=use_graph)
314
+ return _create(
315
+ mode=mode, # type: ignore[arg-type]
316
+ max_iterations=max_iterations,
317
+ max_time_minutes=max_time_minutes,
318
+ use_graph=use_graph,
319
+ )
320
+ except Exception as e:
321
+ logger.error("Failed to create graph orchestrator", error=str(e))
322
+ raise ConfigurationError(f"Failed to create graph orchestrator: {e}") from e
323
+
324
+
325
+ def create_graph_builder() -> "GraphBuilder":
326
+ """
327
+ Create a graph builder instance.
328
+
329
+ Returns:
330
+ GraphBuilder instance
331
+ """
332
+ from src.agent_factory.graph_builder import GraphBuilder
333
+
334
+ try:
335
+ logger.debug("Creating graph builder")
336
+ return GraphBuilder()
337
+ except Exception as e:
338
+ logger.error("Failed to create graph builder", error=str(e))
339
+ raise ConfigurationError(f"Failed to create graph builder: {e}") from e
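A minimal usage sketch for the factories above (illustrative, not part of the diff). It assumes model credentials are configured behind `get_model()` and that `IterativeResearchFlow` exposes an async `run(query)` method; the flow class itself is defined in `src/orchestrator/research_flow.py` later in this commit, so the method name is an assumption here.

```python
import asyncio

from src.agent_factory.agents import create_iterative_flow


async def main() -> None:
    # use_graph=None defers to settings.use_graph_execution
    flow = create_iterative_flow(max_iterations=3, max_time_minutes=5, verbose=True)
    # run() is assumed from the flow's role; see research_flow.py in this commit
    report = await flow.run("What existing drugs might help treat long COVID fatigue?")
    print(report)


asyncio.run(main())
```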
src/agent_factory/graph_builder.py ADDED
@@ -0,0 +1,608 @@
1
+ """Graph builder utilities for constructing research workflow graphs.
2
+
3
+ Provides classes and utilities for building graph-based orchestration systems
4
+ using Pydantic AI agents as nodes.
5
+ """
6
+
7
+ from collections.abc import Callable
8
+ from typing import TYPE_CHECKING, Any, Literal
9
+
10
+ import structlog
11
+ from pydantic import BaseModel, Field
12
+
13
+ if TYPE_CHECKING:
14
+ from pydantic_ai import Agent
15
+
16
+ from src.middleware.state_machine import WorkflowState
17
+
18
+ logger = structlog.get_logger()
19
+
20
+
21
+ # ============================================================================
22
+ # Graph Node Models
23
+ # ============================================================================
24
+
25
+
26
+ class GraphNode(BaseModel):
27
+ """Base class for graph nodes."""
28
+
29
+ node_id: str = Field(description="Unique identifier for the node")
30
+ node_type: Literal["agent", "state", "decision", "parallel"] = Field(description="Type of node")
31
+ description: str = Field(default="", description="Human-readable description of the node")
32
+
33
+ model_config = {"frozen": True}
34
+
35
+
36
+ class AgentNode(GraphNode):
37
+ """Node that executes a Pydantic AI agent."""
38
+
39
+ node_type: Literal["agent"] = "agent"
40
+ agent: Any = Field(description="Pydantic AI agent to execute")
41
+ input_transformer: Callable[[Any], Any] | None = Field(
42
+ default=None, description="Transform input before passing to agent"
43
+ )
44
+ output_transformer: Callable[[Any], Any] | None = Field(
45
+ default=None, description="Transform output after agent execution"
46
+ )
47
+
48
+ model_config = {"arbitrary_types_allowed": True}
49
+
50
+
51
+ class StateNode(GraphNode):
52
+ """Node that updates or reads workflow state."""
53
+
54
+ node_type: Literal["state"] = "state"
55
+ state_updater: Callable[[Any, Any], Any] = Field(
56
+ description="Function to update workflow state"
57
+ )
58
+ state_reader: Callable[[Any], Any] | None = Field(
59
+ default=None, description="Function to read state (optional)"
60
+ )
61
+
62
+ model_config = {"arbitrary_types_allowed": True}
63
+
64
+
65
+ class DecisionNode(GraphNode):
66
+ """Node that makes routing decisions based on conditions."""
67
+
68
+ node_type: Literal["decision"] = "decision"
69
+ decision_function: Callable[[Any], str] = Field(
70
+ description="Function that returns next node ID based on input"
71
+ )
72
+ options: list[str] = Field(description="List of possible next node IDs", min_length=1)
73
+
74
+ model_config = {"arbitrary_types_allowed": True}
75
+
76
+
77
+ class ParallelNode(GraphNode):
78
+ """Node that executes multiple nodes in parallel."""
79
+
80
+ node_type: Literal["parallel"] = "parallel"
81
+ parallel_nodes: list[str] = Field(
82
+ description="List of node IDs to run in parallel", min_length=1
83
+ )
84
+ aggregator: Callable[[list[Any]], Any] | None = Field(
85
+ default=None, description="Function to aggregate parallel results"
86
+ )
87
+
88
+ model_config = {"arbitrary_types_allowed": True}
89
+
90
+
91
+ # ============================================================================
92
+ # Graph Edge Models
93
+ # ============================================================================
94
+
95
+
96
+ class GraphEdge(BaseModel):
97
+ """Base class for graph edges."""
98
+
99
+ from_node: str = Field(description="Source node ID")
100
+ to_node: str = Field(description="Target node ID")
101
+ condition: Callable[[Any], bool] | None = Field(
102
+ default=None, description="Optional condition function"
103
+ )
104
+ weight: float = Field(default=1.0, description="Edge weight for routing decisions")
105
+
106
+ model_config = {"arbitrary_types_allowed": True}
107
+
108
+
109
+ class SequentialEdge(GraphEdge):
110
+ """Edge that is always traversed (no condition)."""
111
+
112
+ condition: None = None
113
+
114
+
115
+ class ConditionalEdge(GraphEdge):
116
+ """Edge that is traversed based on a condition."""
117
+
118
+ condition: Callable[[Any], bool] = Field(description="Required condition function")
119
+ condition_description: str = Field(
120
+ default="", description="Human-readable description of condition"
121
+ )
122
+
123
+
124
+ class ParallelEdge(GraphEdge):
125
+ """Edge used for parallel execution branches."""
126
+
127
+ condition: None = None
128
+
129
+
130
+ # ============================================================================
131
+ # Research Graph Class
132
+ # ============================================================================
133
+
134
+
135
+ class ResearchGraph(BaseModel):
136
+ """Represents a research workflow graph with nodes and edges."""
137
+
138
+ nodes: dict[str, GraphNode] = Field(default_factory=dict, description="All nodes in the graph")
139
+ edges: dict[str, list[GraphEdge]] = Field(
140
+ default_factory=dict, description="Edges by source node ID"
141
+ )
142
+ entry_node: str = Field(description="Starting node ID")
143
+ exit_nodes: list[str] = Field(default_factory=list, description="Terminal node IDs")
144
+
145
+ model_config = {"arbitrary_types_allowed": True}
146
+
147
+ def add_node(self, node: GraphNode) -> None:
148
+ """Add a node to the graph.
149
+
150
+ Args:
151
+ node: The node to add
152
+
153
+ Raises:
154
+ ValueError: If node ID already exists
155
+ """
156
+ if node.node_id in self.nodes:
157
+ raise ValueError(f"Node {node.node_id} already exists in graph")
158
+ self.nodes[node.node_id] = node
159
+ logger.debug("Node added to graph", node_id=node.node_id, type=node.node_type)
160
+
161
+ def add_edge(self, edge: GraphEdge) -> None:
162
+ """Add an edge to the graph.
163
+
164
+ Args:
165
+ edge: The edge to add
166
+
167
+ Raises:
168
+ ValueError: If source or target node doesn't exist
169
+ """
170
+ if edge.from_node not in self.nodes:
171
+ raise ValueError(f"Source node {edge.from_node} not found in graph")
172
+ if edge.to_node not in self.nodes:
173
+ raise ValueError(f"Target node {edge.to_node} not found in graph")
174
+
175
+ if edge.from_node not in self.edges:
176
+ self.edges[edge.from_node] = []
177
+ self.edges[edge.from_node].append(edge)
178
+ logger.debug(
179
+ "Edge added to graph",
180
+ from_node=edge.from_node,
181
+ to_node=edge.to_node,
182
+ )
183
+
184
+ def get_node(self, node_id: str) -> GraphNode | None:
185
+ """Get a node by ID.
186
+
187
+ Args:
188
+ node_id: The node ID
189
+
190
+ Returns:
191
+ The node, or None if not found
192
+ """
193
+ return self.nodes.get(node_id)
194
+
195
+ def get_next_nodes(self, node_id: str, context: Any = None) -> list[tuple[str, GraphEdge]]:
196
+ """Get all possible next nodes from a given node.
197
+
198
+ Args:
199
+ node_id: The current node ID
200
+ context: Optional context for evaluating conditions
201
+
202
+ Returns:
203
+ List of (node_id, edge) tuples for valid next nodes
204
+ """
205
+ if node_id not in self.edges:
206
+ return []
207
+
208
+ next_nodes = []
209
+ for edge in self.edges[node_id]:
210
+ # Evaluate condition if present
211
+ if edge.condition is None or edge.condition(context):
212
+ next_nodes.append((edge.to_node, edge))
213
+
214
+ return next_nodes
215
+
216
+ def validate_structure(self) -> list[str]:
217
+ """Validate the graph structure.
218
+
219
+ Returns:
220
+ List of validation error messages (empty if valid)
221
+ """
222
+ errors = []
223
+
224
+ # Check entry node exists
225
+ if self.entry_node not in self.nodes:
226
+ errors.append(f"Entry node {self.entry_node} not found in graph")
227
+
228
+ # Check exit nodes exist and at least one is defined
229
+ if not self.exit_nodes:
230
+ errors.append("At least one exit node must be defined")
231
+ for exit_node in self.exit_nodes:
232
+ if exit_node not in self.nodes:
233
+ errors.append(f"Exit node {exit_node} not found in graph")
234
+
235
+ # Check all edges reference valid nodes
236
+ for from_node, edge_list in self.edges.items():
237
+ if from_node not in self.nodes:
238
+ errors.append(f"Edge source node {from_node} not found")
239
+ for edge in edge_list:
240
+ if edge.to_node not in self.nodes:
241
+ errors.append(f"Edge target node {edge.to_node} not found")
242
+
243
+ # Check all nodes are reachable from entry node (basic check)
244
+ if self.entry_node in self.nodes:
245
+ reachable = {self.entry_node}
246
+ queue = [self.entry_node]
247
+ while queue:
248
+ current = queue.pop(0)
249
+ for next_node, _ in self.get_next_nodes(current):
250
+ if next_node not in reachable:
251
+ reachable.add(next_node)
252
+ queue.append(next_node)
253
+
254
+ unreachable = set(self.nodes.keys()) - reachable
255
+ if unreachable:
256
+ errors.append(f"Unreachable nodes from entry node: {', '.join(unreachable)}")
257
+
258
+ return errors
259
+
260
+
261
+ # ============================================================================
262
+ # Graph Builder Class
263
+ # ============================================================================
264
+
265
+
266
+ class GraphBuilder:
267
+ """Builder for constructing research workflow graphs."""
268
+
269
+ def __init__(self) -> None:
270
+ """Initialize the graph builder."""
271
+ self.graph = ResearchGraph(entry_node="", exit_nodes=[])
272
+
273
+ def add_agent_node(
274
+ self,
275
+ node_id: str,
276
+ agent: "Agent[Any, Any]",
277
+ description: str = "",
278
+ input_transformer: Callable[[Any], Any] | None = None,
279
+ output_transformer: Callable[[Any], Any] | None = None,
280
+ ) -> "GraphBuilder":
281
+ """Add an agent node to the graph.
282
+
283
+ Args:
284
+ node_id: Unique identifier for the node
285
+ agent: Pydantic AI agent to execute
286
+ description: Human-readable description
287
+ input_transformer: Optional input transformation function
288
+ output_transformer: Optional output transformation function
289
+
290
+ Returns:
291
+ Self for method chaining
292
+ """
293
+ node = AgentNode(
294
+ node_id=node_id,
295
+ agent=agent,
296
+ description=description,
297
+ input_transformer=input_transformer,
298
+ output_transformer=output_transformer,
299
+ )
300
+ self.graph.add_node(node)
301
+ return self
302
+
303
+ def add_state_node(
304
+ self,
305
+ node_id: str,
306
+ state_updater: Callable[["WorkflowState", Any], "WorkflowState"],
307
+ description: str = "",
308
+ state_reader: Callable[["WorkflowState"], Any] | None = None,
309
+ ) -> "GraphBuilder":
310
+ """Add a state node to the graph.
311
+
312
+ Args:
313
+ node_id: Unique identifier for the node
314
+ state_updater: Function to update workflow state
315
+ description: Human-readable description
316
+ state_reader: Optional function to read state
317
+
318
+ Returns:
319
+ Self for method chaining
320
+ """
321
+ node = StateNode(
322
+ node_id=node_id,
323
+ state_updater=state_updater,
324
+ description=description,
325
+ state_reader=state_reader,
326
+ )
327
+ self.graph.add_node(node)
328
+ return self
329
+
330
+ def add_decision_node(
331
+ self,
332
+ node_id: str,
333
+ decision_function: Callable[[Any], str],
334
+ options: list[str],
335
+ description: str = "",
336
+ ) -> "GraphBuilder":
337
+ """Add a decision node to the graph.
338
+
339
+ Args:
340
+ node_id: Unique identifier for the node
341
+ decision_function: Function that returns next node ID
342
+ options: List of possible next node IDs
343
+ description: Human-readable description
344
+
345
+ Returns:
346
+ Self for method chaining
347
+ """
348
+ node = DecisionNode(
349
+ node_id=node_id,
350
+ decision_function=decision_function,
351
+ options=options,
352
+ description=description,
353
+ )
354
+ self.graph.add_node(node)
355
+ return self
356
+
357
+ def add_parallel_node(
358
+ self,
359
+ node_id: str,
360
+ parallel_nodes: list[str],
361
+ description: str = "",
362
+ aggregator: Callable[[list[Any]], Any] | None = None,
363
+ ) -> "GraphBuilder":
364
+ """Add a parallel node to the graph.
365
+
366
+ Args:
367
+ node_id: Unique identifier for the node
368
+ parallel_nodes: List of node IDs to run in parallel
369
+ description: Human-readable description
370
+ aggregator: Optional function to aggregate results
371
+
372
+ Returns:
373
+ Self for method chaining
374
+ """
375
+ node = ParallelNode(
376
+ node_id=node_id,
377
+ parallel_nodes=parallel_nodes,
378
+ description=description,
379
+ aggregator=aggregator,
380
+ )
381
+ self.graph.add_node(node)
382
+ return self
383
+
384
+ def connect_nodes(
385
+ self,
386
+ from_node: str,
387
+ to_node: str,
388
+ condition: Callable[[Any], bool] | None = None,
389
+ condition_description: str = "",
390
+ ) -> "GraphBuilder":
391
+ """Connect two nodes with an edge.
392
+
393
+ Args:
394
+ from_node: Source node ID
395
+ to_node: Target node ID
396
+ condition: Optional condition function
397
+ condition_description: Description of condition (if conditional)
398
+
399
+ Returns:
400
+ Self for method chaining
401
+ """
402
+ if condition is None:
403
+ edge: GraphEdge = SequentialEdge(from_node=from_node, to_node=to_node)
404
+ else:
405
+ edge = ConditionalEdge(
406
+ from_node=from_node,
407
+ to_node=to_node,
408
+ condition=condition,
409
+ condition_description=condition_description,
410
+ )
411
+ self.graph.add_edge(edge)
412
+ return self
413
+
414
+ def set_entry_node(self, node_id: str) -> "GraphBuilder":
415
+ """Set the entry node for the graph.
416
+
417
+ Args:
418
+ node_id: The entry node ID
419
+
420
+ Returns:
421
+ Self for method chaining
422
+ """
423
+ self.graph.entry_node = node_id
424
+ return self
425
+
426
+ def set_exit_nodes(self, node_ids: list[str]) -> "GraphBuilder":
427
+ """Set the exit nodes for the graph.
428
+
429
+ Args:
430
+ node_ids: List of exit node IDs
431
+
432
+ Returns:
433
+ Self for method chaining
434
+ """
435
+ self.graph.exit_nodes = node_ids
436
+ return self
437
+
438
+ def build(self) -> ResearchGraph:
439
+ """Finalize graph construction and validate.
440
+
441
+ Returns:
442
+ The constructed ResearchGraph
443
+
444
+ Raises:
445
+ ValueError: If graph validation fails
446
+ """
447
+ errors = self.graph.validate_structure()
448
+ if errors:
449
+ error_msg = "Graph validation failed:\n" + "\n".join(f" - {e}" for e in errors)
450
+ logger.error("Graph validation failed", errors=errors)
451
+ raise ValueError(error_msg)
452
+
453
+ logger.info(
454
+ "Graph built successfully",
455
+ nodes=len(self.graph.nodes),
456
+ edges=sum(len(edges) for edges in self.graph.edges.values()),
457
+ entry_node=self.graph.entry_node,
458
+ exit_nodes=self.graph.exit_nodes,
459
+ )
460
+ return self.graph
461
+
462
+
463
+ # ============================================================================
464
+ # Factory Functions
465
+ # ============================================================================
466
+
467
+
468
+ def create_iterative_graph(
469
+ knowledge_gap_agent: "Agent[Any, Any]",
470
+ tool_selector_agent: "Agent[Any, Any]",
471
+ thinking_agent: "Agent[Any, Any]",
472
+ writer_agent: "Agent[Any, Any]",
473
+ ) -> ResearchGraph:
474
+ """Create a graph for iterative research flow.
475
+
476
+ Args:
477
+ knowledge_gap_agent: Agent for evaluating knowledge gaps
478
+ tool_selector_agent: Agent for selecting tools
479
+ thinking_agent: Agent for generating observations
480
+ writer_agent: Agent for writing final report
481
+
482
+ Returns:
483
+ Constructed ResearchGraph for iterative research
484
+ """
485
+ builder = GraphBuilder()
486
+
487
+ # Add nodes
488
+ builder.add_agent_node("thinking", thinking_agent, "Generate observations")
489
+ builder.add_agent_node("knowledge_gap", knowledge_gap_agent, "Evaluate knowledge gaps")
490
+ builder.add_decision_node(
491
+ "continue_decision",
492
+ decision_function=lambda result: "writer"
493
+ if getattr(result, "research_complete", False)
494
+ else "tool_selector",
495
+ options=["tool_selector", "writer"],
496
+ description="Decide whether to continue research or write report",
497
+ )
498
+ builder.add_agent_node("tool_selector", tool_selector_agent, "Select tools to address gap")
499
+ builder.add_state_node(
500
+ "execute_tools",
501
+ state_updater=lambda state, tasks: state,
502
+ # Placeholder - actual execution handled separately
503
+ description="Execute selected tools",
504
+ )
505
+ builder.add_agent_node("writer", writer_agent, "Write final report")
506
+
507
+ # Add edges
508
+ builder.connect_nodes("thinking", "knowledge_gap")
509
+ builder.connect_nodes("knowledge_gap", "continue_decision")
510
+ builder.connect_nodes("continue_decision", "tool_selector")
511
+ builder.connect_nodes("continue_decision", "writer")
512
+ builder.connect_nodes("tool_selector", "execute_tools")
513
+ builder.connect_nodes("execute_tools", "thinking") # Loop back
514
+
515
+ # Set entry and exit
516
+ builder.set_entry_node("thinking")
517
+ builder.set_exit_nodes(["writer"])
518
+
519
+ return builder.build()
520
+
521
+
522
+ def create_deep_graph(
523
+ planner_agent: "Agent[Any, Any]",
524
+ knowledge_gap_agent: "Agent[Any, Any]",
525
+ tool_selector_agent: "Agent[Any, Any]",
526
+ thinking_agent: "Agent[Any, Any]",
527
+ writer_agent: "Agent[Any, Any]",
528
+ long_writer_agent: "Agent[Any, Any]",
529
+ ) -> ResearchGraph:
530
+ """Create a graph for deep research flow.
531
+
532
+ The graph structure: planner → store_plan → parallel_loops → collect_drafts → synthesizer
533
+
534
+ Args:
535
+ planner_agent: Agent for creating report plan
536
+ knowledge_gap_agent: Agent for evaluating knowledge gaps (not used directly, but needed for iterative flows)
537
+ tool_selector_agent: Agent for selecting tools (not used directly, but needed for iterative flows)
538
+ thinking_agent: Agent for generating observations (not used directly, but needed for iterative flows)
539
+ writer_agent: Agent for writing section reports (not used directly, but needed for iterative flows)
540
+ long_writer_agent: Agent for synthesizing final report
541
+
542
+ Returns:
543
+ Constructed ResearchGraph for deep research
544
+ """
545
+ from src.utils.models import ReportPlan
546
+
547
+ builder = GraphBuilder()
548
+
549
+ # Add nodes
550
+ # 1. Planner agent - creates report plan
551
+ builder.add_agent_node("planner", planner_agent, "Create report plan with sections")
552
+
553
+ # 2. State node - store report plan in workflow state
554
+ def store_plan(state: "WorkflowState", plan: ReportPlan) -> "WorkflowState":
555
+ """Store report plan in state for parallel loops to access."""
556
+ # Store plan in a custom attribute (we'll need to extend WorkflowState or use a dict)
557
+ # For now, we'll store it in the context's node_results
558
+ # The actual storage will happen in the graph execution
559
+ return state
560
+
561
+ builder.add_state_node(
562
+ "store_plan",
563
+ state_updater=store_plan,
564
+ description="Store report plan in state",
565
+ )
566
+
567
+ # 3. Parallel node - will execute iterative research flows for each section
568
+ # The actual execution will be handled dynamically in _execute_parallel_node()
569
+ # We use a special node ID that the executor will recognize
570
+ builder.add_parallel_node(
571
+ "parallel_loops",
572
+ parallel_nodes=[], # Will be populated dynamically based on report plan
573
+ description="Execute parallel iterative research loops for each section",
574
+ aggregator=lambda results: results, # Collect all section drafts
575
+ )
576
+
577
+ # 4. State node - collect section drafts into ReportDraft
578
+ def collect_drafts(state: "WorkflowState", section_drafts: list[str]) -> "WorkflowState":
579
+ """Collect section drafts into state for synthesizer."""
580
+ # Store drafts in state (will be accessed by synthesizer)
581
+ return state
582
+
583
+ builder.add_state_node(
584
+ "collect_drafts",
585
+ state_updater=collect_drafts,
586
+ description="Collect section drafts for synthesis",
587
+ )
588
+
589
+ # 5. Synthesizer agent - creates final report from drafts
590
+ builder.add_agent_node(
591
+ "synthesizer", long_writer_agent, "Synthesize final report from section drafts"
592
+ )
593
+
594
+ # Add edges
595
+ builder.connect_nodes("planner", "store_plan")
596
+ builder.connect_nodes("store_plan", "parallel_loops")
597
+ builder.connect_nodes("parallel_loops", "collect_drafts")
598
+ builder.connect_nodes("collect_drafts", "synthesizer")
599
+
600
+ # Set entry and exit
601
+ builder.set_entry_node("planner")
602
+ builder.set_exit_nodes(["synthesizer"])
603
+
604
+ return builder.build()
605
+
606
+
607
+ # No need to rebuild models since we're using Any types
608
+ # The models will work correctly with arbitrary_types_allowed=True
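To make the builder API concrete, here is a small self-contained sketch of the fluent interface above. It mirrors `create_iterative_graph`: edges leaving a decision node are left unconditional (the node's `decision_function` picks the branch at run time), which also keeps `validate_structure()` happy, since its reachability pass evaluates edge conditions against a `None` context.

```python
from src.agent_factory.graph_builder import GraphBuilder

builder = GraphBuilder()
builder.add_state_node("start", state_updater=lambda state, result: state, description="Seed state")
builder.add_decision_node(
    "route",
    decision_function=lambda result: "done" if result else "start",
    options=["start", "done"],
    description="Loop until a truthy result arrives",
)
builder.add_state_node("done", state_updater=lambda state, result: state, description="Terminal node")

# Unconditional fan-out from the decision node, as in create_iterative_graph
builder.connect_nodes("start", "route")
builder.connect_nodes("route", "done")
builder.connect_nodes("route", "start")

graph = builder.set_entry_node("start").set_exit_nodes(["done"]).build()
print(sorted(graph.nodes))  # ['done', 'route', 'start']
```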
src/agent_factory/judges.py CHANGED
@@ -351,6 +351,15 @@ IMPORTANT: Respond with ONLY valid JSON matching this schema:
351
  )
352
 
353
 
354
+ def create_judge_handler() -> JudgeHandler:
355
+ """Create a judge handler based on configuration.
356
+
357
+ Returns:
358
+ Configured JudgeHandler instance
359
+ """
360
+ return JudgeHandler()
361
+
362
+
363
  class MockJudgeHandler:
364
  """
365
  Mock JudgeHandler for demo mode without LLM calls.
src/agents/input_parser.py ADDED
@@ -0,0 +1,178 @@
1
+ """Input parser agent for analyzing and improving user queries.
2
+
3
+ Determines research mode (iterative vs deep) and extracts key information
4
+ from user queries to improve research quality.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, Literal
8
+
9
+ import structlog
10
+ from pydantic_ai import Agent
11
+
12
+ from src.agent_factory.judges import get_model
13
+ from src.utils.exceptions import ConfigurationError, JudgeError
14
+ from src.utils.models import ParsedQuery
15
+
16
+ if TYPE_CHECKING:
17
+ pass
18
+
19
+ logger = structlog.get_logger()
20
+
21
+ # System prompt for the input parser agent
22
+ SYSTEM_PROMPT = """
23
+ You are an expert research query analyzer. Your job is to analyze user queries and determine:
24
+ 1. Whether the query requires iterative research (single focused question) or deep research (multiple sections/topics)
25
+ 2. Improve and refine the query for better research results
26
+ 3. Extract key entities (drugs, diseases, targets, companies, etc.)
27
+ 4. Extract specific research questions
28
+
29
+ Guidelines for determining research mode:
30
+ - **Iterative mode**: Single focused question, straightforward research goal, can be answered with a focused search loop
31
+ Examples: "What is the mechanism of metformin?", "Find clinical trials for drug X"
32
+
33
+ - **Deep mode**: Complex query requiring multiple sections, comprehensive report, multiple related topics
34
+ Examples: "Write a comprehensive report on diabetes treatment", "Analyze the market for quantum computing"
35
+ Indicators: words like "comprehensive", "report", "sections", "analyze", "market analysis", "overview"
36
+
37
+ Your output must be valid JSON matching the ParsedQuery schema. Always provide:
38
+ - original_query: The exact input query
39
+ - improved_query: A refined, clearer version of the query
40
+ - research_mode: Either "iterative" or "deep"
41
+ - key_entities: List of important entities (drugs, diseases, companies, etc.)
42
+ - research_questions: List of specific questions to answer
43
+
44
+ Only output JSON. Do not output anything else.
45
+ """
46
+
47
+
48
+ class InputParserAgent:
49
+ """
50
+ Input parser agent that analyzes queries and determines research mode.
51
+
52
+ Uses Pydantic AI to generate structured ParsedQuery output with research
53
+ mode detection, query improvement, and entity extraction.
54
+ """
55
+
56
+ def __init__(self, model: Any | None = None) -> None:
57
+ """
58
+ Initialize the input parser agent.
59
+
60
+ Args:
61
+ model: Optional Pydantic AI model. If None, uses config default.
62
+ """
63
+ self.model = model or get_model()
64
+ self.logger = logger
65
+
66
+ # Initialize Pydantic AI Agent
67
+ self.agent = Agent(
68
+ model=self.model,
69
+ output_type=ParsedQuery,
70
+ system_prompt=SYSTEM_PROMPT,
71
+ retries=3,
72
+ )
73
+
74
+ async def parse(self, query: str) -> ParsedQuery:
75
+ """
76
+ Parse and analyze a user query.
77
+
78
+ Args:
79
+ query: The user's research query
80
+
81
+ Returns:
82
+ ParsedQuery with research mode, improved query, entities, and questions
83
+
84
+ Raises:
85
+ JudgeError: If parsing fails after retries
86
+ ConfigurationError: If agent configuration is invalid
87
+ """
88
+ self.logger.info("Parsing user query", query=query[:100])
89
+
90
+ user_message = f"QUERY: {query}"
91
+
92
+ try:
93
+ # Run the agent
94
+ result = await self.agent.run(user_message)
95
+ parsed_query = result.output
96
+
97
+ # Validate parsed query
98
+ if not parsed_query.original_query:
99
+ self.logger.warning("Parsed query missing original_query", query=query[:100])
100
+ raise JudgeError("Parsed query must have original_query")
101
+
102
+ if not parsed_query.improved_query:
103
+ self.logger.warning("Parsed query missing improved_query", query=query[:100])
104
+ # Use original as fallback
105
+ parsed_query = ParsedQuery(
106
+ original_query=parsed_query.original_query,
107
+ improved_query=parsed_query.original_query,
108
+ research_mode=parsed_query.research_mode,
109
+ key_entities=parsed_query.key_entities,
110
+ research_questions=parsed_query.research_questions,
111
+ )
112
+
113
+ self.logger.info(
114
+ "Query parsed successfully",
115
+ mode=parsed_query.research_mode,
116
+ entities=len(parsed_query.key_entities),
117
+ questions=len(parsed_query.research_questions),
118
+ )
119
+
120
+ return parsed_query
121
+
122
+ except Exception as e:
123
+ self.logger.error("Query parsing failed", error=str(e), query=query[:100])
124
+
125
+ # Fallback: return basic parsed query with heuristic mode detection
126
+ if isinstance(e, JudgeError | ConfigurationError):
127
+ raise
128
+
129
+ # Heuristic fallback
130
+ query_lower = query.lower()
131
+ research_mode: Literal["iterative", "deep"] = "iterative"
132
+ if any(
133
+ keyword in query_lower
134
+ for keyword in [
135
+ "comprehensive",
136
+ "report",
137
+ "sections",
138
+ "analyze",
139
+ "analysis",
140
+ "overview",
141
+ "market",
142
+ ]
143
+ ):
144
+ research_mode = "deep"
145
+
146
+ return ParsedQuery(
147
+ original_query=query,
148
+ improved_query=query,
149
+ research_mode=research_mode,
150
+ key_entities=[],
151
+ research_questions=[],
152
+ )
153
+
154
+
155
+ def create_input_parser_agent(model: Any | None = None) -> InputParserAgent:
156
+ """
157
+ Factory function to create an input parser agent.
158
+
159
+ Args:
160
+ model: Optional Pydantic AI model. If None, uses settings default.
161
+
162
+ Returns:
163
+ Configured InputParserAgent instance
164
+
165
+ Raises:
166
+ ConfigurationError: If required API keys are missing
167
+ """
168
+ try:
169
+ # Get model from settings if not provided
170
+ if model is None:
171
+ model = get_model()
172
+
173
+ # Create and return input parser agent
174
+ return InputParserAgent(model=model)
175
+
176
+ except Exception as e:
177
+ logger.error("Failed to create input parser agent", error=str(e))
178
+ raise ConfigurationError(f"Failed to create input parser agent: {e}") from e
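A short sketch of the parser in use (assumes API keys are configured; if the LLM call fails, the keyword heuristic above still returns a usable `ParsedQuery`):

```python
import asyncio

from src.agents.input_parser import create_input_parser_agent


async def main() -> None:
    parser = create_input_parser_agent()
    parsed = await parser.parse("Write a comprehensive report on diabetes treatment")
    # "comprehensive" / "report" should steer research_mode toward "deep"
    print(parsed.research_mode, parsed.key_entities, parsed.research_questions)


asyncio.run(main())
```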
src/agents/judge_agent.py CHANGED
@@ -12,7 +12,7 @@ from agent_framework import (
12
  Role,
13
  )
14
 
15
- from src.orchestrator import JudgeHandlerProtocol
15
+ from src.legacy_orchestrator import JudgeHandlerProtocol
16
  from src.utils.models import Evidence, JudgeAssessment
17
 
18
 
src/agents/knowledge_gap.py ADDED
@@ -0,0 +1,156 @@
1
+ """Knowledge gap agent for evaluating research completeness.
2
+
3
+ Converts the folder/knowledge_gap_agent.py implementation to use Pydantic AI.
4
+ """
5
+
6
+ from datetime import datetime
7
+ from typing import Any
8
+
9
+ import structlog
10
+ from pydantic_ai import Agent
11
+
12
+ from src.agent_factory.judges import get_model
13
+ from src.utils.exceptions import ConfigurationError
14
+ from src.utils.models import KnowledgeGapOutput
15
+
16
+ logger = structlog.get_logger()
17
+
18
+
19
+ # System prompt for the knowledge gap agent
20
+ SYSTEM_PROMPT = f"""
21
+ You are a Research State Evaluator. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
22
+ Your job is to critically analyze the current state of a research report,
23
+ identify what knowledge gaps still exist and determine the best next step to take.
24
+
25
+ You will be given:
26
+ 1. The original user query and any relevant background context to the query
27
+ 2. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process
28
+
29
+ Your task is to:
30
+ 1. Carefully review the findings and thoughts, particularly from the latest iteration, and assess their completeness in answering the original query
31
+ 2. Determine if the findings are sufficiently complete to end the research loop
32
+ 3. If not, identify up to 3 knowledge gaps that need to be addressed in sequence in order to continue with research - these should be relevant to the original query
33
+
34
+ Be specific in the gaps you identify and include relevant information as this will be passed on to another agent to process without additional context.
35
+
36
+ Only output JSON. Follow the JSON schema for KnowledgeGapOutput. Do not output anything else.
37
+ """
38
+
39
+
40
+ class KnowledgeGapAgent:
41
+ """
42
+ Agent that evaluates research state and identifies knowledge gaps.
43
+
44
+ Uses Pydantic AI to generate structured KnowledgeGapOutput indicating
45
+ whether research is complete and what gaps remain.
46
+ """
47
+
48
+ def __init__(self, model: Any | None = None) -> None:
49
+ """
50
+ Initialize the knowledge gap agent.
51
+
52
+ Args:
53
+ model: Optional Pydantic AI model. If None, uses config default.
54
+ """
55
+ self.model = model or get_model()
56
+ self.logger = logger
57
+
58
+ # Initialize Pydantic AI Agent
59
+ self.agent = Agent(
60
+ model=self.model,
61
+ output_type=KnowledgeGapOutput,
62
+ system_prompt=SYSTEM_PROMPT,
63
+ retries=3,
64
+ )
65
+
66
+ async def evaluate(
67
+ self,
68
+ query: str,
69
+ background_context: str = "",
70
+ conversation_history: str = "",
71
+ iteration: int = 0,
72
+ time_elapsed_minutes: float = 0.0,
73
+ max_time_minutes: int = 10,
74
+ ) -> KnowledgeGapOutput:
75
+ """
76
+ Evaluate research state and identify knowledge gaps.
77
+
78
+ Args:
79
+ query: The original research query
80
+ background_context: Optional background context
81
+ conversation_history: History of actions, findings, and thoughts
82
+ iteration: Current iteration number
83
+ time_elapsed_minutes: Time elapsed so far
84
+ max_time_minutes: Maximum time allowed
85
+
86
+ Returns:
87
+ KnowledgeGapOutput with research completeness and outstanding gaps
88
+
89
+ Raises:
90
+ JudgeError: If evaluation fails after retries
91
+ """
92
+ self.logger.info(
93
+ "Evaluating knowledge gaps",
94
+ query=query[:100],
95
+ iteration=iteration,
96
+ )
97
+
98
+ background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
99
+
100
+ user_message = f"""
101
+ Current Iteration Number: {iteration}
102
+ Time Elapsed: {time_elapsed_minutes:.2f} minutes of maximum {max_time_minutes} minutes
103
+
104
+ ORIGINAL QUERY:
105
+ {query}
106
+
107
+ {background}
108
+
109
+ HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
110
+ {conversation_history or "No previous actions, findings or thoughts available."}
111
+ """
112
+
113
+ try:
114
+ # Run the agent
115
+ result = await self.agent.run(user_message)
116
+ evaluation = result.output
117
+
118
+ self.logger.info(
119
+ "Knowledge gap evaluation complete",
120
+ research_complete=evaluation.research_complete,
121
+ gaps_count=len(evaluation.outstanding_gaps),
122
+ )
123
+
124
+ return evaluation
125
+
126
+ except Exception as e:
127
+ self.logger.error("Knowledge gap evaluation failed", error=str(e))
128
+ # Return fallback: research not complete, suggest continuing
129
+ return KnowledgeGapOutput(
130
+ research_complete=False,
131
+ outstanding_gaps=[f"Continue research on: {query}"],
132
+ )
133
+
134
+
135
+ def create_knowledge_gap_agent(model: Any | None = None) -> KnowledgeGapAgent:
136
+ """
137
+ Factory function to create a knowledge gap agent.
138
+
139
+ Args:
140
+ model: Optional Pydantic AI model. If None, uses settings default.
141
+
142
+ Returns:
143
+ Configured KnowledgeGapAgent instance
144
+
145
+ Raises:
146
+ ConfigurationError: If required API keys are missing
147
+ """
148
+ try:
149
+ if model is None:
150
+ model = get_model()
151
+
152
+ return KnowledgeGapAgent(model=model)
153
+
154
+ except Exception as e:
155
+ logger.error("Failed to create knowledge gap agent", error=str(e))
156
+ raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e
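Sketch of driving the gap evaluator inside a research loop (illustrative; the `conversation_history` string stands in for the real action/finding log that the flow accumulates):

```python
import asyncio

from src.agents.knowledge_gap import create_knowledge_gap_agent


async def main() -> None:
    agent = create_knowledge_gap_agent()
    evaluation = await agent.evaluate(
        query="What is the mechanism of metformin?",
        conversation_history="Iteration 0: searched PubMed; found AMPK activation papers.",
        iteration=1,
        time_elapsed_minutes=2.5,
        max_time_minutes=10,
    )
    if not evaluation.research_complete:
        print("Next gaps:", evaluation.outstanding_gaps)


asyncio.run(main())
```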
src/agents/long_writer.py ADDED
@@ -0,0 +1,431 @@
1
+ """Long writer agent for iteratively writing report sections.
2
+
3
+ Converts the folder/long_writer_agent.py implementation to use Pydantic AI.
4
+ """
5
+
6
+ import re
7
+ from datetime import datetime
8
+ from typing import Any
9
+
10
+ import structlog
11
+ from pydantic import BaseModel, Field
12
+ from pydantic_ai import Agent
13
+
14
+ from src.agent_factory.judges import get_model
15
+ from src.utils.exceptions import ConfigurationError
16
+ from src.utils.models import ReportDraft
17
+
18
+ logger = structlog.get_logger()
19
+
20
+
21
+ # LongWriterOutput model for structured output
22
+ class LongWriterOutput(BaseModel):
23
+ """Output from the long writer agent for a single section."""
24
+
25
+ next_section_markdown: str = Field(
26
+ description="The final draft of the next section in markdown format"
27
+ )
28
+ references: list[str] = Field(
29
+ description="A list of URLs and their corresponding reference numbers for the section"
30
+ )
31
+
32
+ model_config = {"frozen": True}
33
+
34
+
35
+ # System prompt for the long writer agent
36
+ SYSTEM_PROMPT = f"""
37
+ You are an expert report writer tasked with iteratively writing each section of a report.
38
+ Today's date is {datetime.now().strftime('%Y-%m-%d')}.
39
+ You will be provided with:
40
+ 1. The original research query
41
+ 2. A final draft of the report containing the table of contents and all sections written up until this point (in the first iteration there will be no sections written yet)
42
+ 3. A first draft of the next section of the report to be written
43
+
44
+ OBJECTIVE:
45
+ 1. Write a final draft of the next section of the report with numbered citations in square brackets in the body of the report
46
+ 2. Produce a list of references to be appended to the end of the report
47
+
48
+ CITATIONS/REFERENCES:
49
+ The citations should be in numerical order, written in numbered square brackets in the body of the report.
50
+ Separately, a list of all URLs and their corresponding reference numbers will be included at the end of the report.
51
+ Follow the example below for formatting.
52
+
53
+ LongWriterOutput(
54
+ next_section_markdown="The company specializes in IT consulting [1]. It operates in the software services market which is expected to grow at 10% per year [2].",
55
+ references=["[1] https://example.com/first-source-url", "[2] https://example.com/second-source-url"]
56
+ )
57
+
58
+ GUIDELINES:
59
+ - You can reformat and reorganize the flow of the content and headings within a section to flow logically, but DO NOT remove details that were included in the first draft
60
+ - Only remove text from the first draft if it is already mentioned earlier in the report, or if it should be covered in a later section per the table of contents
61
+ - Ensure the heading for the section matches the table of contents
62
+ - Format the final output and references section as markdown
63
+ - Do not include a title for the reference section, just a list of numbered references
64
+
65
+ Only output JSON. Follow the JSON schema for LongWriterOutput. Do not output anything else.
66
+ """
67
+
68
+
69
+ class LongWriterAgent:
70
+ """
71
+ Agent that iteratively writes report sections with proper citations.
72
+
73
+ Uses Pydantic AI to generate structured LongWriterOutput for each section.
74
+ """
75
+
76
+ def __init__(self, model: Any | None = None) -> None:
77
+ """
78
+ Initialize the long writer agent.
79
+
80
+ Args:
81
+ model: Optional Pydantic AI model. If None, uses config default.
82
+ """
83
+ self.model = model or get_model()
84
+ self.logger = logger
85
+
86
+ # Initialize Pydantic AI Agent
87
+ self.agent = Agent(
88
+ model=self.model,
89
+ output_type=LongWriterOutput,
90
+ system_prompt=SYSTEM_PROMPT,
91
+ retries=3,
92
+ )
93
+
94
+ async def write_next_section(
95
+ self,
96
+ original_query: str,
97
+ report_draft: str,
98
+ next_section_title: str,
99
+ next_section_draft: str,
100
+ ) -> LongWriterOutput:
101
+ """
102
+ Write the next section of the report.
103
+
104
+ Args:
105
+ original_query: The original research query
106
+ report_draft: Current report draft (all sections written so far)
107
+ next_section_title: Title of the section to write
108
+ next_section_draft: Draft content for the next section
109
+
110
+ Returns:
111
+ LongWriterOutput with formatted section and references
112
+
113
+ Raises:
114
+ ConfigurationError: If writing fails
115
+ """
116
+ # Input validation
117
+ if not original_query or not original_query.strip():
118
+ self.logger.warning("Empty query provided, using default")
119
+ original_query = "Research query"
120
+
121
+ if not next_section_title or not next_section_title.strip():
122
+ self.logger.warning("Empty section title provided, using default")
123
+ next_section_title = "Section"
124
+
125
+ if next_section_draft is None:
126
+ next_section_draft = ""
127
+
128
+ if report_draft is None:
129
+ report_draft = ""
130
+
131
+ # Truncate very long inputs
132
+ max_draft_length = 30000
133
+ if len(report_draft) > max_draft_length:
134
+ self.logger.warning(
135
+ "Report draft too long, truncating",
136
+ original_length=len(report_draft),
137
+ )
138
+ report_draft = report_draft[:max_draft_length] + "\n\n[Content truncated]"
139
+
140
+ if len(next_section_draft) > max_draft_length:
141
+ self.logger.warning(
142
+ "Section draft too long, truncating",
143
+ original_length=len(next_section_draft),
144
+ )
145
+ next_section_draft = next_section_draft[:max_draft_length] + "\n\n[Content truncated]"
146
+
147
+ self.logger.info(
148
+ "Writing next section",
149
+ section_title=next_section_title,
150
+ query=original_query[:100],
151
+ )
152
+
153
+ user_message = f"""
154
+ <ORIGINAL QUERY>
155
+ {original_query}
156
+ </ORIGINAL QUERY>
157
+
158
+ <CURRENT REPORT DRAFT>
159
+ {report_draft or "No draft yet"}
160
+ </CURRENT REPORT DRAFT>
161
+
162
+ <TITLE OF NEXT SECTION TO WRITE>
163
+ {next_section_title}
164
+ </TITLE OF NEXT SECTION TO WRITE>
165
+
166
+ <DRAFT OF NEXT SECTION>
167
+ {next_section_draft}
168
+ </DRAFT OF NEXT SECTION>
169
+ """
170
+
171
+ # Retry logic for transient failures
172
+ max_retries = 3
173
+ last_exception: Exception | None = None
174
+
175
+ for attempt in range(max_retries):
176
+ try:
177
+ # Run the agent
178
+ result = await self.agent.run(user_message)
179
+ output = result.output
180
+
181
+ # Validate output
182
+ if not output or not isinstance(output, LongWriterOutput):
183
+ raise ValueError("Invalid output format")
184
+
185
+ if not output.next_section_markdown or not output.next_section_markdown.strip():
186
+ self.logger.warning("Empty section generated, using fallback")
187
+ raise ValueError("Empty section generated")
188
+
189
+ self.logger.info(
190
+ "Section written",
191
+ section_title=next_section_title,
192
+ references_count=len(output.references),
193
+ attempt=attempt + 1,
194
+ )
195
+
196
+ return output
197
+
198
+ except (TimeoutError, ConnectionError) as e:
199
+ # Transient errors - retry
200
+ last_exception = e
201
+ if attempt < max_retries - 1:
202
+ self.logger.warning(
203
+ "Transient error, retrying",
204
+ error=str(e),
205
+ attempt=attempt + 1,
206
+ max_retries=max_retries,
207
+ )
208
+ continue
209
+ else:
210
+ self.logger.error("Max retries exceeded for transient error", error=str(e))
211
+ break
212
+
213
+ except Exception as e:
214
+ # Non-transient errors - don't retry
215
+ last_exception = e
216
+ self.logger.error(
217
+ "Section writing failed",
218
+ error=str(e),
219
+ error_type=type(e).__name__,
220
+ )
221
+ break
222
+
223
+ # Return fallback section if all attempts failed
224
+ self.logger.error(
225
+ "Section writing failed after all attempts",
226
+ error=str(last_exception) if last_exception else "Unknown error",
227
+ )
228
+ return LongWriterOutput(
229
+ next_section_markdown=f"## {next_section_title}\n\n{next_section_draft}",
230
+ references=[],
231
+ )
232
+
233
+ async def write_report(
234
+ self,
235
+ original_query: str,
236
+ report_title: str,
237
+ report_draft: ReportDraft,
238
+ ) -> str:
239
+ """
240
+ Write the final report by iteratively writing each section.
241
+
242
+ Args:
243
+ original_query: The original research query
244
+ report_title: Title of the report
245
+ report_draft: ReportDraft with all sections
246
+
247
+ Returns:
248
+ Complete markdown report string
249
+
250
+ Raises:
251
+ ConfigurationError: If writing fails
252
+ """
253
+ # Input validation
254
+ if not original_query or not original_query.strip():
255
+ self.logger.warning("Empty query provided, using default")
256
+ original_query = "Research query"
257
+
258
+ if not report_title or not report_title.strip():
259
+ self.logger.warning("Empty report title provided, using default")
260
+ report_title = "Research Report"
261
+
262
+ if not report_draft or not report_draft.sections:
263
+ self.logger.warning("Empty report draft provided, returning minimal report")
264
+ return f"# {report_title}\n\n## Query\n{original_query}\n\n*No sections available.*"
265
+
266
+ self.logger.info(
267
+ "Writing full report",
268
+ report_title=report_title,
269
+ sections_count=len(report_draft.sections),
270
+ )
271
+
272
+ # Initialize the final draft with title and table of contents
273
+ final_draft = (
274
+ f"# {report_title}\n\n## Table of Contents\n\n"
275
+ + "\n".join(
276
+ [
277
+ f"{i+1}. {section.section_title}"
278
+ for i, section in enumerate(report_draft.sections)
279
+ ]
280
+ )
281
+ + "\n\n"
282
+ )
283
+ all_references: list[str] = []
284
+
285
+ for section in report_draft.sections:
286
+ # Write each section
287
+ next_section_output = await self.write_next_section(
288
+ original_query,
289
+ final_draft,
290
+ section.section_title,
291
+ section.section_content,
292
+ )
293
+
294
+ # Reformat references and update section markdown
295
+ section_markdown, all_references = self._reformat_references(
296
+ next_section_output.next_section_markdown,
297
+ next_section_output.references,
298
+ all_references,
299
+ )
300
+
301
+ # Reformat section headings
302
+ section_markdown = self._reformat_section_headings(section_markdown)
303
+
304
+ # Add to final draft
305
+ final_draft += section_markdown + "\n\n"
306
+
307
+ # Add final references
308
+ final_draft += "## References:\n\n" + " \n".join(all_references)
309
+
310
+ self.logger.info("Full report written", length=len(final_draft))
311
+
312
+ return final_draft
313
+
314
+ def _reformat_references(
315
+ self,
316
+ section_markdown: str,
317
+ section_references: list[str],
318
+ all_references: list[str],
319
+ ) -> tuple[str, list[str]]:
320
+ """
321
+ Reformat references: re-number, de-duplicate, and update markdown.
322
+
323
+ Args:
324
+ section_markdown: Markdown content with inline references [1], [2]
325
+ section_references: List of references for this section
326
+ all_references: Accumulated references from previous sections
327
+
328
+ Returns:
329
+ Tuple of (updated markdown, updated all_references)
330
+ """
331
+
332
+ # Convert reference lists to maps (URL -> ref_num)
333
+ def convert_ref_list_to_map(ref_list: list[str]) -> dict[str, int]:
334
+ ref_map: dict[str, int] = {}
335
+ for ref in ref_list:
336
+ try:
337
+ # Parse "[1] https://example.com" format
338
+ parts = ref.split("]", 1)
339
+ if len(parts) == 2:
340
+ ref_num = int(parts[0].strip("["))
341
+ url = parts[1].strip()
342
+ ref_map[url] = ref_num
343
+ except (ValueError, IndexError):
344
+ logger.warning("Invalid reference format", ref=ref)
345
+ continue
346
+ return ref_map
347
+
348
+ section_ref_map = convert_ref_list_to_map(section_references)
349
+ report_ref_map = convert_ref_list_to_map(all_references)
350
+ section_to_report_ref_map: dict[int, int] = {}
351
+
352
+ report_urls = set(report_ref_map.keys())
353
+ ref_count = max(report_ref_map.values() or [0])
354
+
355
+ # Map section references to report references
356
+ for url, section_ref_num in section_ref_map.items():
357
+ if url in report_urls:
358
+ # URL already exists - reuse its reference number
359
+ section_to_report_ref_map[section_ref_num] = report_ref_map[url]
360
+ else:
361
+ # New URL - assign next reference number
362
+ ref_count += 1
363
+ section_to_report_ref_map[section_ref_num] = ref_count
364
+ all_references.append(f"[{ref_count}] {url}")
365
+
366
+ # Replace reference numbers in markdown
367
+ def replace_reference(match: re.Match[str]) -> str:
368
+ ref_num = int(match.group(1))
369
+ mapped_ref_num = section_to_report_ref_map.get(ref_num)
370
+ if mapped_ref_num:
371
+ return f"[{mapped_ref_num}]"
372
+ return ""
373
+
374
+ updated_markdown = re.sub(r"\[(\d+)\]", replace_reference, section_markdown)
375
+
376
+ return updated_markdown, all_references
377
+
378
+ def _reformat_section_headings(self, section_markdown: str) -> str:
379
+ """
380
+ Reformat section headings to be consistent (level-2 for main heading).
381
+
382
+ Args:
383
+ section_markdown: Markdown content with headings
384
+
385
+ Returns:
386
+ Updated markdown with adjusted heading levels
387
+ """
388
+ if not section_markdown.strip():
389
+ return section_markdown
390
+
391
+ # Find first heading level
392
+ first_heading_match = re.search(r"^(#+)\s", section_markdown, re.MULTILINE)
393
+ if not first_heading_match:
394
+ return section_markdown
395
+
396
+ # Calculate level adjustment needed (target is level 2)
397
+ first_heading_level = len(first_heading_match.group(1))
398
+ level_adjustment = 2 - first_heading_level
399
+
400
+ def adjust_heading_level(match: re.Match[str]) -> str:
401
+ hashes = match.group(1)
402
+ content = match.group(2)
403
+ new_level = max(2, len(hashes) + level_adjustment)
404
+ return "#" * new_level + " " + content
405
+
406
+ # Apply heading adjustment
407
+ return re.sub(r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE)
408
+
409
+
410
+ def create_long_writer_agent(model: Any | None = None) -> LongWriterAgent:
411
+ """
412
+ Factory function to create a long writer agent.
413
+
414
+ Args:
415
+ model: Optional Pydantic AI model. If None, uses settings default.
416
+
417
+ Returns:
418
+ Configured LongWriterAgent instance
419
+
420
+ Raises:
421
+ ConfigurationError: If required API keys are missing
422
+ """
423
+ try:
424
+ if model is None:
425
+ model = get_model()
426
+
427
+ return LongWriterAgent(model=model)
428
+
429
+ except Exception as e:
430
+ logger.error("Failed to create long writer agent", error=str(e))
431
+ raise ConfigurationError(f"Failed to create long writer agent: {e}") from e
src/agents/proofreader.py ADDED
@@ -0,0 +1,205 @@
1
+ """Proofreader agent for finalizing report drafts.
2
+
3
+ Converts the folder/proofreader_agent.py implementation to use Pydantic AI.
4
+ """
5
+
6
+ from datetime import datetime
7
+ from typing import Any
8
+
9
+ import structlog
10
+ from pydantic_ai import Agent
11
+
12
+ from src.agent_factory.judges import get_model
13
+ from src.utils.exceptions import ConfigurationError
14
+ from src.utils.models import ReportDraft
15
+
16
+ logger = structlog.get_logger()
17
+
18
+
19
+ # System prompt for the proofreader agent
20
+ SYSTEM_PROMPT = f"""
21
+ You are a research expert who proofreads and edits research reports.
22
+ Today's date is {datetime.now().strftime("%Y-%m-%d")}.
23
+
24
+ You are given:
25
+ 1. The original query topic for the report
26
+ 2. A first draft of the report in ReportDraft format containing each section in sequence
27
+
28
+ Your task is to:
29
+ 1. **Combine sections:** Concatenate the sections into a single string
30
+ 2. **Add section titles:** Add the section titles to the beginning of each section in markdown format, as well as a main title for the report
31
+ 3. **De-duplicate:** Remove duplicate content across sections to avoid repetition
32
+ 4. **Remove irrelevant sections:** If any sections or sub-sections are completely irrelevant to the query, remove them
33
+ 5. **Refine wording:** Edit the wording of the report to be polished, concise and punchy, but **without eliminating any detail** or large chunks of text
34
+ 6. **Add a summary:** Add a short report summary / outline to the beginning of the report to provide an overview of the sections and what is discussed
35
+ 7. **Preserve sources:** Preserve all sources / references - move the long list of references to the end of the report
36
+ 8. **Update reference numbers:** Continue to include reference numbers in square brackets ([1], [2], [3], etc.) in the main body of the report, but update the numbering to match the new order of references at the end of the report
37
+ 9. **Output final report:** Output the final report in markdown format (do not wrap it in a code block)
38
+
39
+ Guidelines:
40
+ - Do not add any new facts or data to the report
41
+ - Do not remove any content from the report unless it is very clearly wrong, contradictory or irrelevant
42
+ - Remove or reformat any redundant or excessive headings, and ensure that the final nesting of heading levels is correct
43
+ - Ensure that the final report flows well and has a logical structure
44
+ - Include all sources and references that are present in the final report
45
+ """
46
+
47
+
48
+ class ProofreaderAgent:
49
+ """
50
+ Agent that proofreads and finalizes report drafts.
51
+
52
+ Uses Pydantic AI to generate polished markdown reports from draft sections.
53
+ """
54
+
55
+ def __init__(self, model: Any | None = None) -> None:
56
+ """
57
+ Initialize the proofreader agent.
58
+
59
+ Args:
60
+ model: Optional Pydantic AI model. If None, uses config default.
61
+ """
62
+ self.model = model or get_model()
63
+ self.logger = logger
64
+
65
+ # Initialize Pydantic AI Agent (no structured output - returns markdown text)
66
+ self.agent = Agent(
67
+ model=self.model,
68
+ system_prompt=SYSTEM_PROMPT,
69
+ retries=3,
70
+ )
71
+
72
+ async def proofread(
73
+ self,
74
+ query: str,
75
+ report_draft: ReportDraft,
76
+ ) -> str:
77
+ """
78
+ Proofread and finalize a report draft.
79
+
80
+ Args:
81
+ query: The original research query
82
+ report_draft: ReportDraft with all sections
83
+
84
+ Returns:
85
+ Final polished markdown report string
86
+
87
+ Raises:
88
+ ConfigurationError: If proofreading fails
89
+ """
90
+ # Input validation
91
+ if not query or not query.strip():
92
+ self.logger.warning("Empty query provided, using default")
93
+ query = "Research query"
94
+
95
+ if not report_draft or not report_draft.sections:
96
+ self.logger.warning("Empty report draft provided, returning minimal report")
97
+ return f"# Research Report\n\n## Query\n{query}\n\n*No sections available.*"
98
+
99
+ # Validate section structure
100
+ valid_sections = []
101
+ for section in report_draft.sections:
102
+ if section.section_title and section.section_title.strip():
103
+ valid_sections.append(section)
104
+ else:
105
+ self.logger.warning("Skipping section with empty title")
106
+
107
+ if not valid_sections:
108
+ self.logger.warning("No valid sections in draft, returning minimal report")
109
+ return f"# Research Report\n\n## Query\n{query}\n\n*No valid sections available.*"
110
+
111
+ self.logger.info(
112
+ "Proofreading report",
113
+ query=query[:100],
114
+ sections_count=len(valid_sections),
115
+ )
116
+
117
+ # Create validated draft
118
+ validated_draft = ReportDraft(sections=valid_sections)
119
+
120
+ user_message = f"""
121
+ QUERY:
122
+ {query}
123
+
124
+ REPORT DRAFT:
125
+ {validated_draft.model_dump_json()}
126
+ """
127
+
128
+ # Retry logic for transient failures
129
+ max_retries = 3
130
+ last_exception: Exception | None = None
131
+
132
+ for attempt in range(max_retries):
133
+ try:
134
+ # Run the agent
135
+ result = await self.agent.run(user_message)
136
+ final_report = result.output
137
+
138
+ # Validate output
139
+ if not final_report or not final_report.strip():
140
+ self.logger.warning("Empty report generated, using fallback")
141
+ raise ValueError("Empty report generated")
142
+
143
+ self.logger.info("Report proofread", length=len(final_report), attempt=attempt + 1)
144
+
145
+ return final_report
146
+
147
+ except (TimeoutError, ConnectionError) as e:
148
+ # Transient errors - retry
149
+ last_exception = e
150
+ if attempt < max_retries - 1:
151
+ self.logger.warning(
152
+ "Transient error, retrying",
153
+ error=str(e),
154
+ attempt=attempt + 1,
155
+ max_retries=max_retries,
156
+ )
157
+ continue
158
+ else:
159
+ self.logger.error("Max retries exceeded for transient error", error=str(e))
160
+ break
161
+
162
+ except Exception as e:
163
+ # Non-transient errors - don't retry
164
+ last_exception = e
165
+ self.logger.error(
166
+ "Proofreading failed",
167
+ error=str(e),
168
+ error_type=type(e).__name__,
169
+ )
170
+ break
171
+
172
+ # Return fallback: combine sections manually
173
+ self.logger.error(
174
+ "Proofreading failed after all attempts",
175
+ error=str(last_exception) if last_exception else "Unknown error",
176
+ )
177
+ sections = [
178
+ f"## {section.section_title}\n\n{section.section_content or 'Content unavailable.'}"
179
+ for section in valid_sections
180
+ ]
181
+ return f"# Research Report\n\n## Query\n{query}\n\n" + "\n\n".join(sections)
182
+
183
+
184
+ def create_proofreader_agent(model: Any | None = None) -> ProofreaderAgent:
185
+ """
186
+ Factory function to create a proofreader agent.
187
+
188
+ Args:
189
+ model: Optional Pydantic AI model. If None, uses settings default.
190
+
191
+ Returns:
192
+ Configured ProofreaderAgent instance
193
+
194
+ Raises:
195
+ ConfigurationError: If required API keys are missing
196
+ """
197
+ try:
198
+ if model is None:
199
+ model = get_model()
200
+
201
+ return ProofreaderAgent(model=model)
202
+
203
+ except Exception as e:
204
+ logger.error("Failed to create proofreader agent", error=str(e))
205
+ raise ConfigurationError(f"Failed to create proofreader agent: {e}") from e
src/agents/search_agent.py CHANGED
@@ -10,7 +10,7 @@ from agent_framework import (
10
  Role,
11
  )
12
 
13
- from src.orchestrator import SearchHandlerProtocol
14
  from src.utils.models import Citation, Evidence, SearchResult
15
 
16
  if TYPE_CHECKING:
 
10
  Role,
11
  )
12
 
13
+ from src.legacy_orchestrator import SearchHandlerProtocol
14
  from src.utils.models import Citation, Evidence, SearchResult
15
 
16
  if TYPE_CHECKING:
src/agents/state.py CHANGED
@@ -1,9 +1,11 @@
1
  """Thread-safe state management for Magentic agents.
2
 
3
- Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users
4
- searching simultaneously via Gradio).
 
5
  """
6
 
 
7
  from contextvars import ContextVar
8
  from typing import TYPE_CHECKING, Any
9
 
@@ -15,8 +17,20 @@ if TYPE_CHECKING:
15
  from src.services.embeddings import EmbeddingService
16
 
17
 
 
 
 
 
 
 
 
 
 
18
  class MagenticState(BaseModel):
19
- """Mutable state for a Magentic workflow session."""
 
 
 
20
 
21
  evidence: list[Evidence] = Field(default_factory=list)
22
  # Type as Any to avoid circular imports/runtime resolution issues
@@ -75,14 +89,22 @@ _magentic_state_var: ContextVar[MagenticState | None] = ContextVar("magentic_sta
75
 
76
 
77
  def init_magentic_state(embedding_service: "EmbeddingService | None" = None) -> MagenticState:
78
- """Initialize a new state for the current context."""
 
 
 
 
79
  state = MagenticState(embedding_service=embedding_service)
80
  _magentic_state_var.set(state)
81
  return state
82
 
83
 
84
  def get_magentic_state() -> MagenticState:
85
- """Get the current state. Raises RuntimeError if not initialized."""
 
 
 
 
86
  state = _magentic_state_var.get()
87
  if state is None:
88
  # Auto-initialize if missing (e.g. during tests or simple scripts)
 
1
  """Thread-safe state management for Magentic agents.
2
 
3
+ DEPRECATED: This module is deprecated. Use src.middleware.state_machine instead.
4
+
5
+ This file is kept for backward compatibility and will be removed in a future version.
6
  """
7
 
8
+ import warnings
9
  from contextvars import ContextVar
10
  from typing import TYPE_CHECKING, Any
11
 
 
17
  from src.services.embeddings import EmbeddingService
18
 
19
 
20
+ def _deprecation_warning() -> None:
21
+ """Emit deprecation warning for this module."""
22
+ warnings.warn(
23
+ "src.agents.state is deprecated. Use src.middleware.state_machine instead.",
24
+ DeprecationWarning,
25
+ stacklevel=3,
26
+ )
27
+
28
+
29
  class MagenticState(BaseModel):
30
+ """Mutable state for a Magentic workflow session.
31
+
32
+ DEPRECATED: Use WorkflowState from src.middleware.state_machine instead.
33
+ """
34
 
35
  evidence: list[Evidence] = Field(default_factory=list)
36
  # Type as Any to avoid circular imports/runtime resolution issues
 
89
 
90
 
91
  def init_magentic_state(embedding_service: "EmbeddingService | None" = None) -> MagenticState:
92
+ """Initialize a new state for the current context.
93
+
94
+ DEPRECATED: Use init_workflow_state from src.middleware.state_machine instead.
95
+ """
96
+ _deprecation_warning()
97
  state = MagenticState(embedding_service=embedding_service)
98
  _magentic_state_var.set(state)
99
  return state
100
 
101
 
102
  def get_magentic_state() -> MagenticState:
103
+ """Get the current state. Raises RuntimeError if not initialized.
104
+
105
+ DEPRECATED: Use get_workflow_state from src.middleware.state_machine instead.
106
+ """
107
+ _deprecation_warning()
108
  state = _magentic_state_var.get()
109
  if state is None:
110
  # Auto-initialize if missing (e.g. during tests or simple scripts)
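The shim keeps the old entry points working while steering callers to `src.middleware.state_machine`; a minimal sketch of the observable behavior:

```python
# Sketch: the legacy helper still returns a usable state but now warns.
import warnings

from src.agents.state import init_magentic_state

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    state = init_magentic_state()  # same behavior as before the refactor
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```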
src/agents/thinking.py ADDED
@@ -0,0 +1,148 @@
1
+ """Thinking agent for generating observations and reflections.
2
+
3
+ Converts the folder/thinking_agent.py implementation to use Pydantic AI.
4
+ """
5
+
6
+ from datetime import datetime
7
+ from typing import Any
8
+
9
+ import structlog
10
+ from pydantic_ai import Agent
11
+
12
+ from src.agent_factory.judges import get_model
13
+ from src.utils.exceptions import ConfigurationError
14
+
15
+ logger = structlog.get_logger()
16
+
17
+
18
+ # System prompt for the thinking agent
19
+ SYSTEM_PROMPT = f"""
20
+ You are a research expert who is managing a research process in iterations. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
21
+
22
+ You are given:
23
+ 1. The original research query along with some supporting background context
24
+ 2. A history of the tasks, actions, findings and thoughts you've made up until this point in the research process (on iteration 1 you will be at the start of the research process, so this will be empty)
25
+
26
+ Your objective is to reflect on the research process so far and share your latest thoughts.
27
+
28
+ Specifically, your thoughts should include reflections on questions such as:
29
+ - What have you learned from the last iteration?
30
+ - What new areas would you like to explore next, or existing topics you'd like to go deeper into?
31
+ - Were you able to retrieve the information you were looking for in the last iteration?
32
+ - If not, should we change our approach or move to the next topic?
33
+ - Is there any info that is contradictory or conflicting?
34
+
35
+ Guidelines:
36
+ - Share your stream of consciousness on the above questions as raw text
37
+ - Keep your response concise and informal
38
+ - Focus most of your thoughts on the most recent iteration and how that influences this next iteration
39
+ - Our aim is to do very deep and thorough research - bear this in mind when reflecting on the research process
40
+ - DO NOT produce a draft of the final report. This is not your job.
41
+ - If this is the first iteration (i.e. no data from prior iterations), provide thoughts on what info we need to gather in the first iteration to get started
42
+ """
43
+
44
+
45
+ class ThinkingAgent:
46
+ """
47
+ Agent that generates observations and reflections on the research process.
48
+
49
+ Uses Pydantic AI to generate unstructured text observations about
50
+ the current state of research and next steps.
51
+ """
52
+
53
+ def __init__(self, model: Any | None = None) -> None:
54
+ """
55
+ Initialize the thinking agent.
56
+
57
+ Args:
58
+ model: Optional Pydantic AI model. If None, uses config default.
59
+ """
60
+ self.model = model or get_model()
61
+ self.logger = logger
62
+
63
+ # Initialize Pydantic AI Agent (no structured output - returns text)
64
+ self.agent = Agent(
65
+ model=self.model,
66
+ system_prompt=SYSTEM_PROMPT,
67
+ retries=3,
68
+ )
69
+
70
+ async def generate_observations(
71
+ self,
72
+ query: str,
73
+ background_context: str = "",
74
+ conversation_history: str = "",
75
+ iteration: int = 1,
76
+ ) -> str:
77
+ """
78
+ Generate observations about the research process.
79
+
80
+ Args:
81
+ query: The original research query
82
+ background_context: Optional background context
83
+ conversation_history: History of actions, findings, and thoughts
84
+ iteration: Current iteration number
85
+
86
+ Returns:
87
+ String containing observations and reflections
88
+
89
+ Raises:
90
+ ConfigurationError: If generation fails
91
+ """
92
+ self.logger.info(
93
+ "Generating observations",
94
+ query=query[:100],
95
+ iteration=iteration,
96
+ )
97
+
98
+ background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
99
+
100
+ user_message = f"""
101
+ You are starting iteration {iteration} of your research process.
102
+
103
+ ORIGINAL QUERY:
104
+ {query}
105
+
106
+ {background}
107
+
108
+ HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
109
+ {conversation_history or "No previous actions, findings or thoughts available."}
110
+ """
111
+
112
+ try:
113
+ # Run the agent
114
+ result = await self.agent.run(user_message)
115
+ observations = result.output
116
+
117
+ self.logger.info("Observations generated", length=len(observations))
118
+
119
+ return observations
120
+
121
+ except Exception as e:
122
+ self.logger.error("Observation generation failed", error=str(e))
123
+ # Return fallback observations
124
+ return f"Starting iteration {iteration}. Need to gather information about: {query}"
125
+
126
+
127
+ def create_thinking_agent(model: Any | None = None) -> ThinkingAgent:
128
+ """
129
+ Factory function to create a thinking agent.
130
+
131
+ Args:
132
+ model: Optional Pydantic AI model. If None, uses settings default.
133
+
134
+ Returns:
135
+ Configured ThinkingAgent instance
136
+
137
+ Raises:
138
+ ConfigurationError: If required API keys are missing
139
+ """
140
+ try:
141
+ if model is None:
142
+ model = get_model()
143
+
144
+ return ThinkingAgent(model=model)
145
+
146
+ except Exception as e:
147
+ logger.error("Failed to create thinking agent", error=str(e))
148
+ raise ConfigurationError(f"Failed to create thinking agent: {e}") from e
src/agents/tool_selector.py ADDED
@@ -0,0 +1,168 @@
1
+ """Tool selector agent for choosing which tools to use for knowledge gaps.
2
+
3
+ Converts the folder/tool_selector_agent.py implementation to use Pydantic AI.
4
+ """
5
+
6
+ from datetime import datetime
7
+ from typing import Any
8
+
9
+ import structlog
10
+ from pydantic_ai import Agent
11
+
12
+ from src.agent_factory.judges import get_model
13
+ from src.utils.exceptions import ConfigurationError
14
+ from src.utils.models import AgentSelectionPlan
15
+
16
+ logger = structlog.get_logger()
17
+
18
+
19
+ # System prompt for the tool selector agent
20
+ SYSTEM_PROMPT = f"""
21
+ You are a Tool Selector responsible for determining which specialized agents should address a knowledge gap in a research project.
22
+ Today's date is {datetime.now().strftime("%Y-%m-%d")}.
23
+
24
+ You will be given:
25
+ 1. The original user query
26
+ 2. A knowledge gap identified in the research
27
+ 3. A full history of the tasks, actions, findings and thoughts you've made up until this point in the research process
28
+
29
+ Your task is to decide:
30
+ 1. Which specialized agents are best suited to address the gap
31
+ 2. What specific queries should be given to the agents (keep this short - 3-6 words)
32
+
33
+ Available specialized agents:
34
+ - WebSearchAgent: General web search for broad topics (can be called multiple times with different queries)
35
+ - SiteCrawlerAgent: Crawl the pages of a specific website to retrieve information about it - use this if you want to find out something about a particular company, entity or product
36
+ - RAGAgent: Semantic search within previously collected evidence - use when you need to find information from evidence already gathered in this research session. Best for finding connections, summarizing collected evidence, or retrieving specific details from earlier findings.
37
+
38
+ Guidelines:
39
+ - Aim to call at most 3 agents at a time in your final output
40
+ - You can list the WebSearchAgent multiple times with different queries if needed to cover the full scope of the knowledge gap
41
+ - Be specific and concise (3-6 words) with the agent queries - they should target exactly what information is needed
42
+ - If you know the website or domain name of an entity being researched, always include it in the query
43
+ - Use RAGAgent when: (1) You need to search within evidence already collected, (2) You want to find connections between different findings, (3) You need to retrieve specific details from earlier research iterations
44
+ - Use WebSearchAgent or SiteCrawlerAgent when: (1) You need fresh information from the web, (2) You're starting a new research direction, (3) You need information not yet in the collected evidence
45
+ - If a gap doesn't clearly match any agent's capability, default to the WebSearchAgent
46
+ - Use the history of actions / tool calls as a guide - try not to repeat yourself if an approach didn't work previously
47
+
48
+ Only output JSON. Follow the JSON schema for AgentSelectionPlan. Do not output anything else.
49
+ """
50
+
51
+
52
+ class ToolSelectorAgent:
53
+ """
54
+ Agent that selects appropriate tools to address knowledge gaps.
55
+
56
+ Uses Pydantic AI to generate structured AgentSelectionPlan with
57
+ specific tasks for web search and crawl agents.
58
+ """
59
+
60
+ def __init__(self, model: Any | None = None) -> None:
61
+ """
62
+ Initialize the tool selector agent.
63
+
64
+ Args:
65
+ model: Optional Pydantic AI model. If None, uses config default.
66
+ """
67
+ self.model = model or get_model()
68
+ self.logger = logger
69
+
70
+ # Initialize Pydantic AI Agent
71
+ self.agent = Agent(
72
+ model=self.model,
73
+ output_type=AgentSelectionPlan,
74
+ system_prompt=SYSTEM_PROMPT,
75
+ retries=3,
76
+ )
77
+
78
+ async def select_tools(
79
+ self,
80
+ gap: str,
81
+ query: str,
82
+ background_context: str = "",
83
+ conversation_history: str = "",
84
+ ) -> AgentSelectionPlan:
85
+ """
86
+ Select tools to address a knowledge gap.
87
+
88
+ Args:
89
+ gap: The knowledge gap to address
90
+ query: The original research query
91
+ background_context: Optional background context
92
+ conversation_history: History of actions, findings, and thoughts
93
+
94
+ Returns:
95
+ AgentSelectionPlan with tasks for selected agents
96
+
97
+ Raises:
98
+ ConfigurationError: If selection fails
99
+ """
100
+ self.logger.info("Selecting tools for gap", gap=gap[:100], query=query[:100])
101
+
102
+ background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
103
+
104
+ user_message = f"""
105
+ ORIGINAL QUERY:
106
+ {query}
107
+
108
+ KNOWLEDGE GAP TO ADDRESS:
109
+ {gap}
110
+
111
+ {background}
112
+
113
+ HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
114
+ {conversation_history or "No previous actions, findings or thoughts available."}
115
+ """
116
+
117
+ try:
118
+ # Run the agent
119
+ result = await self.agent.run(user_message)
120
+ selection_plan = result.output
121
+
122
+ self.logger.info(
123
+ "Tool selection complete",
124
+ tasks_count=len(selection_plan.tasks),
125
+ agents=[task.agent for task in selection_plan.tasks],
126
+ )
127
+
128
+ return selection_plan
129
+
130
+ except Exception as e:
131
+ self.logger.error("Tool selection failed", error=str(e))
132
+ # Return fallback: use web search
133
+ from src.utils.models import AgentTask
134
+
135
+ return AgentSelectionPlan(
136
+ tasks=[
137
+ AgentTask(
138
+ gap=gap,
139
+ agent="WebSearchAgent",
140
+ query=gap[:50], # Use gap as query
141
+ entity_website=None,
142
+ )
143
+ ]
144
+ )
145
+
146
+
147
+ def create_tool_selector_agent(model: Any | None = None) -> ToolSelectorAgent:
148
+ """
149
+ Factory function to create a tool selector agent.
150
+
151
+ Args:
152
+ model: Optional Pydantic AI model. If None, uses settings default.
153
+
154
+ Returns:
155
+ Configured ToolSelectorAgent instance
156
+
157
+ Raises:
158
+ ConfigurationError: If required API keys are missing
159
+ """
160
+ try:
161
+ if model is None:
162
+ model = get_model()
163
+
164
+ return ToolSelectorAgent(model=model)
165
+
166
+ except Exception as e:
167
+ logger.error("Failed to create tool selector agent", error=str(e))
168
+ raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e
src/agents/writer.py ADDED
@@ -0,0 +1,209 @@
1
+ """Writer agent for generating final reports from findings.
2
+
3
+ Converts the folder/writer_agent.py implementation to use Pydantic AI.
4
+ """
5
+
6
+ from datetime import datetime
7
+ from typing import Any
8
+
9
+ import structlog
10
+ from pydantic_ai import Agent
11
+
12
+ from src.agent_factory.judges import get_model
13
+ from src.utils.exceptions import ConfigurationError
14
+
15
+ logger = structlog.get_logger()
16
+
17
+
18
+ # System prompt for the writer agent
19
+ SYSTEM_PROMPT = f"""
20
+ You are a senior researcher tasked with comprehensively answering a research query.
21
+ Today's date is {datetime.now().strftime('%Y-%m-%d')}.
22
+ You will be provided with the original query along with research findings put together by a research assistant.
23
+ Your objective is to generate the final response in markdown format.
24
+ The response should be as lengthy and detailed as possible with the information provided, focusing on answering the original query.
25
+ In your final output, include references to the source URLs for all information and data gathered.
26
+ This should be formatted in the form of a numbered square bracket next to the relevant information,
27
+ followed by a list of URLs at the end of the response, per the example below.
28
+
29
+ EXAMPLE REFERENCE FORMAT:
30
+ The company has XYZ products [1]. It operates in the software services market which is expected to grow at 10% per year [2].
31
+
32
+ References:
33
+ [1] https://example.com/first-source-url
34
+ [2] https://example.com/second-source-url
35
+
36
+ GUIDELINES:
37
+ * Answer the query directly, do not include unrelated or tangential information.
38
+ * Adhere to any instructions on the length of your final response if provided in the user prompt.
39
+ * If any additional guidelines are provided in the user prompt, follow them exactly and give them precedence over these system instructions.
40
+ """
41
+
42
+
43
+ class WriterAgent:
44
+ """
45
+ Agent that generates final reports from research findings.
46
+
47
+ Uses Pydantic AI to generate markdown reports with citations.
48
+ """
49
+
50
+ def __init__(self, model: Any | None = None) -> None:
51
+ """
52
+ Initialize the writer agent.
53
+
54
+ Args:
55
+ model: Optional Pydantic AI model. If None, uses config default.
56
+ """
57
+ self.model = model or get_model()
58
+ self.logger = logger
59
+
60
+ # Initialize Pydantic AI Agent (no structured output - returns markdown text)
61
+ self.agent = Agent(
62
+ model=self.model,
63
+ system_prompt=SYSTEM_PROMPT,
64
+ retries=3,
65
+ )
66
+
67
+ async def write_report(
68
+ self,
69
+ query: str,
70
+ findings: str,
71
+ output_length: str = "",
72
+ output_instructions: str = "",
73
+ ) -> str:
74
+ """
75
+ Write a final report from findings.
76
+
77
+ Args:
78
+ query: The original research query
79
+ findings: All findings collected during research
80
+ output_length: Optional description of desired output length
81
+ output_instructions: Optional additional instructions
82
+
83
+ Returns:
84
+ Markdown formatted report string
85
+
86
+ Raises:
87
+ ConfigurationError: If writing fails
88
+ """
89
+ # Input validation
90
+ if not query or not query.strip():
91
+ self.logger.warning("Empty query provided, using default")
92
+ query = "Research query"
93
+
94
+ if findings is None:
95
+ self.logger.warning("None findings provided, using empty string")
96
+ findings = "No findings available."
97
+
98
+ # Truncate very long inputs to prevent context overflow
99
+ max_findings_length = 50000 # ~12k tokens
100
+ if len(findings) > max_findings_length:
101
+ self.logger.warning(
102
+ "Findings too long, truncating",
103
+ original_length=len(findings),
104
+ truncated_length=max_findings_length,
105
+ )
106
+ findings = findings[:max_findings_length] + "\n\n[Content truncated due to length]"
107
+
108
+ self.logger.info("Writing final report", query=query[:100], findings_length=len(findings))
109
+
110
+ length_str = (
111
+ f"* The full response should be approximately {output_length}.\n"
112
+ if output_length
113
+ else ""
114
+ )
115
+ instructions_str = f"* {output_instructions}" if output_instructions else ""
116
+ guidelines_str = (
117
+ ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n")
118
+ if length_str or instructions_str
119
+ else ""
120
+ )
121
+
122
+ user_message = f"""
123
+ Provide a response based on the query and findings below with as much detail as possible. {guidelines_str}
124
+
125
+ QUERY: {query}
126
+
127
+ FINDINGS:
128
+ {findings}
129
+ """
130
+
131
+ # Retry logic for transient failures
132
+ max_retries = 3
133
+ last_exception: Exception | None = None
134
+
135
+ for attempt in range(max_retries):
136
+ try:
137
+ # Run the agent
138
+ result = await self.agent.run(user_message)
139
+ report = result.output
140
+
141
+ # Validate output
142
+ if not report or not report.strip():
143
+ self.logger.warning("Empty report generated, using fallback")
144
+ raise ValueError("Empty report generated")
145
+
146
+ self.logger.info("Report written", length=len(report), attempt=attempt + 1)
147
+
148
+ return report
149
+
150
+ except (TimeoutError, ConnectionError) as e:
151
+ # Transient errors - retry
152
+ last_exception = e
153
+ if attempt < max_retries - 1:
154
+ self.logger.warning(
155
+ "Transient error, retrying",
156
+ error=str(e),
157
+ attempt=attempt + 1,
158
+ max_retries=max_retries,
159
+ )
160
+ continue
161
+ else:
162
+ self.logger.error("Max retries exceeded for transient error", error=str(e))
163
+ break
164
+
165
+ except Exception as e:
166
+ # Non-transient errors - don't retry
167
+ last_exception = e
168
+ self.logger.error(
169
+ "Report writing failed", error=str(e), error_type=type(e).__name__
170
+ )
171
+ break
172
+
173
+ # Return fallback report if all attempts failed
174
+ self.logger.error(
175
+ "Report writing failed after all attempts",
176
+ error=str(last_exception) if last_exception else "Unknown error",
177
+ )
178
+ # Truncate findings in fallback if too long
179
+ fallback_findings = findings[:500] + "..." if len(findings) > 500 else findings
180
+ return (
181
+ f"# Research Report\n\n"
182
+ f"## Query\n{query}\n\n"
183
+ f"## Findings\n{fallback_findings}\n\n"
184
+ f"*Note: Report generation encountered an error. This is a fallback report.*"
185
+ )
186
+
187
+
188
+ def create_writer_agent(model: Any | None = None) -> WriterAgent:
189
+ """
190
+ Factory function to create a writer agent.
191
+
192
+ Args:
193
+ model: Optional Pydantic AI model. If None, uses settings default.
194
+
195
+ Returns:
196
+ Configured WriterAgent instance
197
+
198
+ Raises:
199
+ ConfigurationError: If required API keys are missing
200
+ """
201
+ try:
202
+ if model is None:
203
+ model = get_model()
204
+
205
+ return WriterAgent(model=model)
206
+
207
+ except Exception as e:
208
+ logger.error("Failed to create writer agent", error=str(e))
209
+ raise ConfigurationError(f"Failed to create writer agent: {e}") from e
src/{orchestrator.py → legacy_orchestrator.py} RENAMED
File without changes
src/middleware/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """Middleware for workflow state management, parallel loop coordination, and budget tracking.
2
+
3
+ This module provides:
4
+ - WorkflowState: Thread-safe state management using ContextVar
5
+ - WorkflowManager: Coordination of parallel research loops
6
+ - BudgetTracker: Token, time, and iteration budget tracking
7
+ """
8
+
9
+ from src.middleware.budget_tracker import BudgetStatus, BudgetTracker
10
+ from src.middleware.state_machine import (
11
+ WorkflowState,
12
+ get_workflow_state,
13
+ init_workflow_state,
14
+ )
15
+ from src.middleware.workflow_manager import (
16
+ LoopStatus,
17
+ ResearchLoop,
18
+ WorkflowManager,
19
+ )
20
+
21
+ __all__ = [
22
+ # State management
23
+ "WorkflowState",
24
+ "init_workflow_state",
25
+ "get_workflow_state",
26
+ # Workflow management
27
+ "WorkflowManager",
28
+ "ResearchLoop",
29
+ "LoopStatus",
30
+ # Budget tracking
31
+ "BudgetTracker",
32
+ "BudgetStatus",
33
+ ]
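A minimal sketch of the three middleware pieces wired together through the package exports above; limits are illustrative.

```python
# Sketch: wiring the middleware exports together at a high level.
from src.middleware import BudgetTracker, WorkflowManager, init_workflow_state

state = init_workflow_state()   # ContextVar-backed, isolated per request
manager = WorkflowManager()     # coordinates parallel research loops
tracker = BudgetTracker()       # enforces token/time/iteration limits
tracker.create_budget("loop-1", tokens_limit=50_000, iterations_limit=5)
```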
src/middleware/budget_tracker.py ADDED
@@ -0,0 +1,390 @@
1
+ """Budget tracking for research loops.
2
+
3
+ Tracks token usage, time elapsed, and iteration counts per loop and globally.
4
+ Enforces budget constraints to prevent infinite loops and excessive resource usage.
5
+ """
6
+
7
+ import time
8
+
9
+ import structlog
10
+ from pydantic import BaseModel, Field
11
+
12
+ logger = structlog.get_logger()
13
+
14
+
15
+ class BudgetStatus(BaseModel):
16
+ """Status of a budget (tokens, time, iterations)."""
17
+
18
+ tokens_used: int = Field(default=0, description="Total tokens used")
19
+ tokens_limit: int = Field(default=100000, description="Token budget limit", ge=0)
20
+ time_elapsed_seconds: float = Field(default=0.0, description="Time elapsed", ge=0.0)
21
+ time_limit_seconds: float = Field(
22
+ default=600.0, description="Time budget limit (10 min default)", ge=0.0
23
+ )
24
+ iterations: int = Field(default=0, description="Number of iterations completed", ge=0)
25
+ iterations_limit: int = Field(default=10, description="Maximum iterations", ge=1)
26
+ iteration_tokens: dict[int, int] = Field(
27
+ default_factory=dict,
28
+ description="Tokens used per iteration (iteration number -> token count)",
29
+ )
30
+
31
+ def is_exceeded(self) -> bool:
32
+ """Check if any budget limit has been exceeded.
33
+
34
+ Returns:
35
+ True if any limit is exceeded, False otherwise.
36
+ """
37
+ return (
38
+ self.tokens_used >= self.tokens_limit
39
+ or self.time_elapsed_seconds >= self.time_limit_seconds
40
+ or self.iterations >= self.iterations_limit
41
+ )
42
+
43
+ def remaining_tokens(self) -> int:
44
+ """Get remaining token budget.
45
+
46
+ Returns:
47
+ Remaining tokens (may be negative if exceeded).
48
+ """
49
+ return self.tokens_limit - self.tokens_used
50
+
51
+ def remaining_time_seconds(self) -> float:
52
+ """Get remaining time budget.
53
+
54
+ Returns:
55
+ Remaining time in seconds (may be negative if exceeded).
56
+ """
57
+ return self.time_limit_seconds - self.time_elapsed_seconds
58
+
59
+ def remaining_iterations(self) -> int:
60
+ """Get remaining iteration budget.
61
+
62
+ Returns:
63
+ Remaining iterations (may be negative if exceeded).
64
+ """
65
+ return self.iterations_limit - self.iterations
66
+
67
+ def add_iteration_tokens(self, iteration: int, tokens: int) -> None:
68
+ """Add tokens for a specific iteration.
69
+
70
+ Args:
71
+ iteration: Iteration number (1-indexed).
72
+ tokens: Number of tokens to add.
73
+ """
74
+ if iteration not in self.iteration_tokens:
75
+ self.iteration_tokens[iteration] = 0
76
+ self.iteration_tokens[iteration] += tokens
77
+ # Also add to total tokens
78
+ self.tokens_used += tokens
79
+
80
+ def get_iteration_tokens(self, iteration: int) -> int:
81
+ """Get tokens used for a specific iteration.
82
+
83
+ Args:
84
+ iteration: Iteration number.
85
+
86
+ Returns:
87
+ Token count for the iteration, or 0 if not found.
88
+ """
89
+ return self.iteration_tokens.get(iteration, 0)
90
+
91
+
92
+ class BudgetTracker:
93
+ """Tracks budgets per loop and globally."""
94
+
95
+ def __init__(self) -> None:
96
+ """Initialize the budget tracker."""
97
+ self._budgets: dict[str, BudgetStatus] = {}
98
+ self._start_times: dict[str, float] = {}
99
+ self._global_budget: BudgetStatus | None = None
100
+
101
+ def create_budget(
102
+ self,
103
+ loop_id: str,
104
+ tokens_limit: int = 100000,
105
+ time_limit_seconds: float = 600.0,
106
+ iterations_limit: int = 10,
107
+ ) -> BudgetStatus:
108
+ """Create a budget for a specific loop.
109
+
110
+ Args:
111
+ loop_id: Unique identifier for the loop.
112
+ tokens_limit: Maximum tokens allowed.
113
+ time_limit_seconds: Maximum time allowed in seconds.
114
+ iterations_limit: Maximum iterations allowed.
115
+
116
+ Returns:
117
+ The created BudgetStatus instance.
118
+ """
119
+ budget = BudgetStatus(
120
+ tokens_limit=tokens_limit,
121
+ time_limit_seconds=time_limit_seconds,
122
+ iterations_limit=iterations_limit,
123
+ )
124
+ self._budgets[loop_id] = budget
125
+ logger.debug(
126
+ "Budget created",
127
+ loop_id=loop_id,
128
+ tokens_limit=tokens_limit,
129
+ time_limit=time_limit_seconds,
130
+ iterations_limit=iterations_limit,
131
+ )
132
+ return budget
133
+
134
+ def get_budget(self, loop_id: str) -> BudgetStatus | None:
135
+ """Get the budget for a specific loop.
136
+
137
+ Args:
138
+ loop_id: Unique identifier for the loop.
139
+
140
+ Returns:
141
+ The BudgetStatus instance, or None if not found.
142
+ """
143
+ return self._budgets.get(loop_id)
144
+
145
+ def add_tokens(self, loop_id: str, tokens: int) -> None:
146
+ """Add tokens to a loop's budget.
147
+
148
+ Args:
149
+ loop_id: Unique identifier for the loop.
150
+ tokens: Number of tokens to add (can be negative).
151
+ """
152
+ if loop_id not in self._budgets:
153
+ logger.warning("Budget not found for loop", loop_id=loop_id)
154
+ return
155
+ self._budgets[loop_id].tokens_used += tokens
156
+ logger.debug("Tokens added", loop_id=loop_id, tokens=tokens)
157
+
158
+ def add_iteration_tokens(self, loop_id: str, iteration: int, tokens: int) -> None:
159
+ """Add tokens for a specific iteration.
160
+
161
+ Args:
162
+ loop_id: Loop identifier.
163
+ iteration: Iteration number (1-indexed).
164
+ tokens: Number of tokens to add.
165
+ """
166
+ if loop_id not in self._budgets:
167
+ logger.warning("Budget not found for loop", loop_id=loop_id)
168
+ return
169
+
170
+ budget = self._budgets[loop_id]
171
+ budget.add_iteration_tokens(iteration, tokens)
172
+
173
+ logger.debug(
174
+ "Iteration tokens added",
175
+ loop_id=loop_id,
176
+ iteration=iteration,
177
+ tokens=tokens,
178
+ total_iteration=budget.get_iteration_tokens(iteration),
179
+ )
180
+
181
+ def get_iteration_tokens(self, loop_id: str, iteration: int) -> int:
182
+ """Get tokens used for a specific iteration.
183
+
184
+ Args:
185
+ loop_id: Loop identifier.
186
+ iteration: Iteration number.
187
+
188
+ Returns:
189
+ Token count for the iteration, or 0 if not found.
190
+ """
191
+ if loop_id not in self._budgets:
192
+ return 0
193
+
194
+ return self._budgets[loop_id].get_iteration_tokens(iteration)
195
+
196
+ def start_timer(self, loop_id: str) -> None:
197
+ """Start the timer for a loop.
198
+
199
+ Args:
200
+ loop_id: Unique identifier for the loop.
201
+ """
202
+ self._start_times[loop_id] = time.time()
203
+ logger.debug("Timer started", loop_id=loop_id)
204
+
205
+ def update_timer(self, loop_id: str) -> None:
206
+ """Update the elapsed time for a loop.
207
+
208
+ Args:
209
+ loop_id: Unique identifier for the loop.
210
+ """
211
+ if loop_id not in self._start_times:
212
+ logger.warning("Timer not started for loop", loop_id=loop_id)
213
+ return
214
+ if loop_id not in self._budgets:
215
+ logger.warning("Budget not found for loop", loop_id=loop_id)
216
+ return
217
+
218
+ elapsed = time.time() - self._start_times[loop_id]
219
+ self._budgets[loop_id].time_elapsed_seconds = elapsed
220
+ logger.debug("Timer updated", loop_id=loop_id, elapsed=elapsed)
221
+
222
+ def increment_iteration(self, loop_id: str) -> None:
223
+ """Increment the iteration count for a loop.
224
+
225
+ Args:
226
+ loop_id: Unique identifier for the loop.
227
+ """
228
+ if loop_id not in self._budgets:
229
+ logger.warning("Budget not found for loop", loop_id=loop_id)
230
+ return
231
+ self._budgets[loop_id].iterations += 1
232
+ logger.debug(
233
+ "Iteration incremented",
234
+ loop_id=loop_id,
235
+ iterations=self._budgets[loop_id].iterations,
236
+ )
237
+
238
+ def check_budget(self, loop_id: str) -> tuple[bool, str]:
239
+ """Check if a loop's budget has been exceeded.
240
+
241
+ Args:
242
+ loop_id: Unique identifier for the loop.
243
+
244
+ Returns:
245
+ Tuple of (exceeded: bool, reason: str). Reason is empty if not exceeded.
246
+ """
247
+ if loop_id not in self._budgets:
248
+ return False, ""
249
+
250
+ budget = self._budgets[loop_id]
251
+ self.update_timer(loop_id) # Update time before checking
252
+
253
+ if budget.is_exceeded():
254
+ reasons = []
255
+ if budget.tokens_used >= budget.tokens_limit:
256
+ reasons.append("tokens")
257
+ if budget.time_elapsed_seconds >= budget.time_limit_seconds:
258
+ reasons.append("time")
259
+ if budget.iterations >= budget.iterations_limit:
260
+ reasons.append("iterations")
261
+ reason = f"Budget exceeded: {', '.join(reasons)}"
262
+ logger.warning("Budget exceeded", loop_id=loop_id, reason=reason)
263
+ return True, reason
264
+
265
+ return False, ""
266
+
267
+ def can_continue(self, loop_id: str) -> bool:
268
+ """Check if a loop can continue based on budget.
269
+
270
+ Args:
271
+ loop_id: Unique identifier for the loop.
272
+
273
+ Returns:
274
+ True if the loop can continue, False if budget is exceeded.
275
+ """
276
+ exceeded, _ = self.check_budget(loop_id)
277
+ return not exceeded
278
+
279
+ def get_budget_summary(self, loop_id: str) -> str:
280
+ """Get a formatted summary of a loop's budget status.
281
+
282
+ Args:
283
+ loop_id: Unique identifier for the loop.
284
+
285
+ Returns:
286
+ Formatted string summary.
287
+ """
288
+ if loop_id not in self._budgets:
289
+ return f"Budget not found for loop: {loop_id}"
290
+
291
+ budget = self._budgets[loop_id]
292
+ self.update_timer(loop_id)
293
+
294
+ return (
295
+ f"Loop {loop_id}: "
296
+ f"Tokens: {budget.tokens_used}/{budget.tokens_limit} "
297
+ f"({budget.remaining_tokens()} remaining), "
298
+ f"Time: {budget.time_elapsed_seconds:.1f}/{budget.time_limit_seconds:.1f}s "
299
+ f"({budget.remaining_time_seconds():.1f}s remaining), "
300
+ f"Iterations: {budget.iterations}/{budget.iterations_limit} "
301
+ f"({budget.remaining_iterations()} remaining)"
302
+ )
303
+
304
+ def reset_budget(self, loop_id: str) -> None:
305
+ """Reset the budget for a loop.
306
+
307
+ Args:
308
+ loop_id: Unique identifier for the loop.
309
+ """
310
+ if loop_id in self._budgets:
311
+ old_budget = self._budgets[loop_id]
312
+ # Preserve iteration_tokens when resetting
313
+ old_iteration_tokens = old_budget.iteration_tokens
314
+ self._budgets[loop_id] = BudgetStatus(
315
+ tokens_limit=old_budget.tokens_limit,
316
+ time_limit_seconds=old_budget.time_limit_seconds,
317
+ iterations_limit=old_budget.iterations_limit,
318
+ iteration_tokens=old_iteration_tokens, # Restore old iteration tokens
319
+ )
320
+ if loop_id in self._start_times:
321
+ self._start_times[loop_id] = time.time()
322
+ logger.debug("Budget reset", loop_id=loop_id)
323
+
324
+ def set_global_budget(
325
+ self,
326
+ tokens_limit: int = 100000,
327
+ time_limit_seconds: float = 600.0,
328
+ iterations_limit: int = 10,
329
+ ) -> None:
330
+ """Set a global budget that applies to all loops.
331
+
332
+ Args:
333
+ tokens_limit: Maximum tokens allowed globally.
334
+ time_limit_seconds: Maximum time allowed in seconds.
335
+ iterations_limit: Maximum iterations allowed globally.
336
+ """
337
+ self._global_budget = BudgetStatus(
338
+ tokens_limit=tokens_limit,
339
+ time_limit_seconds=time_limit_seconds,
340
+ iterations_limit=iterations_limit,
341
+ )
342
+ logger.debug(
343
+ "Global budget set",
344
+ tokens_limit=tokens_limit,
345
+ time_limit=time_limit_seconds,
346
+ iterations_limit=iterations_limit,
347
+ )
348
+
349
+ def get_global_budget(self) -> BudgetStatus | None:
350
+ """Get the global budget.
351
+
352
+ Returns:
353
+ The global BudgetStatus instance, or None if not set.
354
+ """
355
+ return self._global_budget
356
+
357
+ def add_global_tokens(self, tokens: int) -> None:
358
+ """Add tokens to the global budget.
359
+
360
+ Args:
361
+ tokens: Number of tokens to add (can be negative).
362
+ """
363
+ if self._global_budget is None:
364
+ logger.warning("Global budget not set")
365
+ return
366
+ self._global_budget.tokens_used += tokens
367
+ logger.debug("Global tokens added", tokens=tokens)
368
+
369
+ def estimate_tokens(self, text: str) -> int:
370
+ """Estimate token count from text (rough estimate: ~4 chars per token).
371
+
372
+ Args:
373
+ text: Text to estimate tokens for.
374
+
375
+ Returns:
376
+ Estimated token count.
377
+ """
378
+ return len(text) // 4
379
+
380
+ def estimate_llm_call_tokens(self, prompt: str, response: str) -> int:
381
+ """Estimate token count for an LLM call.
382
+
383
+ Args:
384
+ prompt: The prompt text.
385
+ response: The response text.
386
+
387
+ Returns:
388
+ Estimated total token count (prompt + response).
389
+ """
390
+ return self.estimate_tokens(prompt) + self.estimate_tokens(response)
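A minimal sketch of enforcing a per-loop budget around an LLM call; the prompt/response strings are placeholders, and token counts come from the tracker's rough 4-chars-per-token estimate.

```python
# Sketch: one budgeted iteration. Loop ID and limits are illustrative.
from src.middleware.budget_tracker import BudgetTracker

tracker = BudgetTracker()
tracker.create_budget(
    "loop-1", tokens_limit=10_000, time_limit_seconds=120.0, iterations_limit=3
)
tracker.start_timer("loop-1")

prompt, response = "summarize evidence on CoQ10 ...", "Three pilot trials ..."
tracker.increment_iteration("loop-1")
tracker.add_iteration_tokens(
    "loop-1", iteration=1, tokens=tracker.estimate_llm_call_tokens(prompt, response)
)

if tracker.can_continue("loop-1"):
    print(tracker.get_budget_summary("loop-1"))
else:
    print("stopping:", tracker.check_budget("loop-1")[1])
```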
src/middleware/state_machine.py ADDED
@@ -0,0 +1,129 @@
1
+ """Thread-safe state management for workflow agents.
2
+
3
+ Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users
4
+ searching simultaneously via Gradio). Refactored from MagenticState to support both
5
+ iterative and deep research patterns.
6
+ """
7
+
8
+ from contextvars import ContextVar
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ import structlog
12
+ from pydantic import BaseModel, Field
13
+
14
+ from src.utils.models import Citation, Conversation, Evidence
15
+
16
+ if TYPE_CHECKING:
17
+ from src.services.embeddings import EmbeddingService
18
+
19
+ logger = structlog.get_logger()
20
+
21
+
22
+ class WorkflowState(BaseModel):
23
+ """Mutable state for a workflow session.
24
+
25
+ Supports both iterative and deep research patterns by tracking evidence and
26
+ conversation history and by providing semantic search capabilities.
27
+ """
28
+
29
+ evidence: list[Evidence] = Field(default_factory=list)
30
+ conversation: Conversation = Field(default_factory=Conversation)
31
+ # Type as Any to avoid circular imports/runtime resolution issues
32
+ # The actual object injected will be an EmbeddingService instance
33
+ embedding_service: Any = Field(default=None)
34
+
35
+ model_config = {"arbitrary_types_allowed": True}
36
+
37
+ def add_evidence(self, new_evidence: list[Evidence]) -> int:
38
+ """Add new evidence, deduplicating by URL.
39
+
40
+ Args:
41
+ new_evidence: List of Evidence objects to add.
42
+
43
+ Returns:
44
+ Number of *new* items added (excluding duplicates).
45
+ """
46
+ existing_urls = {e.citation.url for e in self.evidence}
47
+ count = 0
48
+ for item in new_evidence:
49
+ if item.citation.url not in existing_urls:
50
+ self.evidence.append(item)
51
+ existing_urls.add(item.citation.url)
52
+ count += 1
53
+ return count
54
+
55
+ async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]:
56
+ """Search for semantically related evidence using the embedding service.
57
+
58
+ Args:
59
+ query: Search query string.
60
+ n_results: Maximum number of results to return.
61
+
62
+ Returns:
63
+ List of Evidence objects, ordered by relevance.
64
+ """
65
+ if not self.embedding_service:
66
+ logger.warning("Embedding service not available, returning empty results")
67
+ return []
68
+
69
+ results = await self.embedding_service.search_similar(query, n_results=n_results)
70
+
71
+ # Convert dict results back to Evidence objects
72
+ evidence_list = []
73
+ for item in results:
74
+ meta = item.get("metadata", {})
75
+ authors_str = meta.get("authors", "")
76
+ authors = [a.strip() for a in authors_str.split(",") if a.strip()]
77
+
78
+ ev = Evidence(
79
+ content=item["content"],
80
+ citation=Citation(
81
+ title=meta.get("title", "Related Evidence"),
82
+ url=item["id"],
83
+ source="pubmed", # Defaulting to pubmed if unknown
84
+ date=meta.get("date", "n.d."),
85
+ authors=authors,
86
+ ),
87
+ relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
88
+ )
89
+ evidence_list.append(ev)
90
+
91
+ return evidence_list
92
+
93
+
94
+ # The ContextVar holds the WorkflowState for the current execution context
95
+ _workflow_state_var: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None)
96
+
97
+
98
+ def init_workflow_state(
99
+ embedding_service: "EmbeddingService | None" = None,
100
+ ) -> WorkflowState:
101
+ """Initialize a new state for the current context.
102
+
103
+ Args:
104
+ embedding_service: Optional embedding service for semantic search.
105
+
106
+ Returns:
107
+ The initialized WorkflowState instance.
108
+ """
109
+ state = WorkflowState(embedding_service=embedding_service)
110
+ _workflow_state_var.set(state)
111
+ logger.debug("Workflow state initialized", has_embeddings=embedding_service is not None)
112
+ return state
113
+
114
+
115
+ def get_workflow_state() -> WorkflowState:
116
+ """Get the current state. Auto-initializes if not set.
117
+
118
+ Returns:
119
+ The current WorkflowState instance.
120
+
121
+ Raises:
122
+ RuntimeError: If state is not initialized and auto-initialization fails.
123
+ """
124
+ state = _workflow_state_var.get()
125
+ if state is None:
126
+ # Auto-initialize if missing (e.g. during tests or simple scripts)
127
+ logger.debug("Workflow state not found, auto-initializing")
128
+ return init_workflow_state()
129
+ return state
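A minimal sketch of the URL-deduplicated evidence store; `Citation` field values are placeholders (field names follow the `search_related` construction above), and `search_related` itself would additionally need an embedding service.

```python
# Sketch: context-isolated state with URL-deduplicated evidence.
from src.middleware.state_machine import get_workflow_state
from src.utils.models import Citation, Evidence

state = get_workflow_state()  # auto-initializes for the current context
ev = Evidence(
    content="CoQ10 improved fatigue scores in a small pilot trial.",
    citation=Citation(
        title="CoQ10 pilot trial",
        url="https://example.com/trial",
        source="pubmed",
        date="n.d.",
        authors=["Doe J"],
    ),
    relevance=0.9,
)
assert state.add_evidence([ev, ev]) == 1  # duplicate URL is dropped
```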
src/middleware/workflow_manager.py ADDED
@@ -0,0 +1,322 @@
+"""Workflow manager for coordinating parallel research loops.
+
+Manages multiple research loops running in parallel, tracks their status,
+and synchronizes evidence between loops and the global state.
+"""
+
+import asyncio
+from collections.abc import Callable
+from typing import Any, Literal
+
+import structlog
+from pydantic import BaseModel, Field
+
+from src.middleware.state_machine import get_workflow_state
+from src.utils.models import Evidence
+
+logger = structlog.get_logger()
+
+LoopStatus = Literal["pending", "running", "completed", "failed", "cancelled"]
+
+
+class ResearchLoop(BaseModel):
+    """Represents a single research loop."""
+
+    loop_id: str = Field(description="Unique identifier for the loop")
+    query: str = Field(description="The research query for this loop")
+    status: LoopStatus = Field(default="pending")
+    evidence: list[Evidence] = Field(default_factory=list)
+    iteration_count: int = Field(default=0, ge=0)
+    error: str | None = Field(default=None)
+
+    model_config = {"frozen": False}  # Mutable for status updates
+
+
+class WorkflowManager:
+    """Manages parallel research loops and state synchronization."""
+
+    def __init__(self) -> None:
+        """Initialize the workflow manager."""
+        self._loops: dict[str, ResearchLoop] = {}
+
+    async def add_loop(self, loop_id: str, query: str) -> ResearchLoop:
+        """Add a new research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            query: The research query for this loop.
+
+        Returns:
+            The created ResearchLoop instance.
+        """
+        loop = ResearchLoop(loop_id=loop_id, query=query, status="pending")
+        self._loops[loop_id] = loop
+        logger.info("Loop added", loop_id=loop_id, query=query)
+        return loop
+
+    async def get_loop(self, loop_id: str) -> ResearchLoop | None:
+        """Get a research loop by ID.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+
+        Returns:
+            The ResearchLoop instance, or None if not found.
+        """
+        return self._loops.get(loop_id)
+
+    async def update_loop_status(
+        self, loop_id: str, status: LoopStatus, error: str | None = None
+    ) -> None:
+        """Update the status of a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            status: New status for the loop.
+            error: Optional error message if status is "failed".
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        self._loops[loop_id].status = status
+        if error:
+            self._loops[loop_id].error = error
+        logger.info("Loop status updated", loop_id=loop_id, status=status)
+
+    async def add_loop_evidence(self, loop_id: str, evidence: list[Evidence]) -> None:
+        """Add evidence to a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+            evidence: List of Evidence objects to add.
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        self._loops[loop_id].evidence.extend(evidence)
+        logger.debug(
+            "Evidence added to loop",
+            loop_id=loop_id,
+            evidence_count=len(evidence),
+        )
+
+    async def increment_loop_iteration(self, loop_id: str) -> None:
+        """Increment the iteration count for a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        self._loops[loop_id].iteration_count += 1
+        logger.debug(
+            "Iteration incremented",
+            loop_id=loop_id,
+            iteration=self._loops[loop_id].iteration_count,
+        )
+
+    async def run_loops_parallel(
+        self,
+        loop_configs: list[dict[str, Any]],
+        loop_func: Callable[[dict[str, Any]], Any],
+        judge_handler: Any | None = None,
+        budget_tracker: Any | None = None,
+    ) -> list[Any]:
+        """Run multiple research loops in parallel.
+
+        Args:
+            loop_configs: List of configuration dicts, each must contain 'loop_id' and 'query'.
+            loop_func: Async function that takes a config dict and returns loop results.
+            judge_handler: Optional JudgeHandler for early termination based on evidence sufficiency.
+            budget_tracker: Optional BudgetTracker for budget enforcement.
+
+        Returns:
+            List of results from each loop (in order of completion, not original order).
+        """
+        logger.info("Starting parallel loops", loop_count=len(loop_configs))
+
+        # Create loops
+        for config in loop_configs:
+            loop_id = config.get("loop_id")
+            query = config.get("query", "")
+            if loop_id:
+                await self.add_loop(loop_id, query)
+                await self.update_loop_status(loop_id, "running")
+
+        # Run loops in parallel
+        async def run_single_loop(config: dict[str, Any]) -> Any:
+            loop_id = config.get("loop_id", "unknown")
+            query = config.get("query", "")
+            try:
+                # Check budget before starting
+                if budget_tracker:
+                    exceeded, reason = budget_tracker.check_budget(loop_id)
+                    if exceeded:
+                        await self.update_loop_status(loop_id, "cancelled", error=reason)
+                        logger.warning(
+                            "Loop cancelled due to budget", loop_id=loop_id, reason=reason
+                        )
+                        return None
+
+                # If loop_func supports periodic checkpoints, we could check judge here
+                # For now, the loop_func itself handles judge checks internally
+                result = await loop_func(config)
+
+                # Final check with judge if available
+                if judge_handler and query:
+                    should_complete, reason = await self.check_loop_completion(
+                        loop_id, query, judge_handler
+                    )
+                    if should_complete:
+                        logger.info(
+                            "Loop completed early based on judge assessment",
+                            loop_id=loop_id,
+                            reason=reason,
+                        )
+
+                await self.update_loop_status(loop_id, "completed")
+                return result
+            except Exception as e:
+                error_msg = str(e)
+                await self.update_loop_status(loop_id, "failed", error=error_msg)
+                logger.error("Loop failed", loop_id=loop_id, error=error_msg)
+                raise
+
+        results = await asyncio.gather(
+            *(run_single_loop(config) for config in loop_configs),
+            return_exceptions=True,
+        )
+
+        # Log completion
+        completed = sum(1 for r in results if not isinstance(r, Exception))
+        failed = len(results) - completed
+        logger.info(
+            "Parallel loops completed",
+            total=len(loop_configs),
+            completed=completed,
+            failed=failed,
+        )
+
+        return results
+
+    async def wait_for_loops(
+        self, loop_ids: list[str], timeout: float | None = None
+    ) -> list[ResearchLoop]:
+        """Wait for loops to complete.
+
+        Args:
+            loop_ids: List of loop IDs to wait for.
+            timeout: Optional timeout in seconds.
+
+        Returns:
+            List of ResearchLoop instances (may be incomplete if timeout occurs).
+        """
+        start_time = asyncio.get_event_loop().time()
+
+        while True:
+            loops = [self._loops.get(loop_id) for loop_id in loop_ids]
+            all_complete = all(
+                loop and loop.status in ("completed", "failed", "cancelled") for loop in loops
+            )
+
+            if all_complete:
+                return [loop for loop in loops if loop is not None]
+
+            if timeout is not None:
+                elapsed = asyncio.get_event_loop().time() - start_time
+                if elapsed >= timeout:
+                    logger.warning("Timeout waiting for loops", timeout=timeout)
+                    return [loop for loop in loops if loop is not None]
+
+            await asyncio.sleep(0.1)  # Small delay to avoid busy waiting
+
+    async def cancel_loop(self, loop_id: str) -> None:
+        """Cancel a research loop.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        await self.update_loop_status(loop_id, "cancelled")
+        logger.info("Loop cancelled", loop_id=loop_id)
+
+    async def get_all_loops(self) -> list[ResearchLoop]:
+        """Get all research loops.
+
+        Returns:
+            List of all ResearchLoop instances.
+        """
+        return list(self._loops.values())
+
+    async def sync_loop_evidence_to_state(self, loop_id: str) -> None:
+        """Synchronize evidence from a loop to the global state.
+
+        Args:
+            loop_id: Unique identifier for the loop.
+        """
+        if loop_id not in self._loops:
+            logger.warning("Loop not found", loop_id=loop_id)
+            return
+
+        loop = self._loops[loop_id]
+        state = get_workflow_state()
+        added_count = state.add_evidence(loop.evidence)
+        logger.debug(
+            "Loop evidence synced to state",
+            loop_id=loop_id,
+            evidence_count=len(loop.evidence),
+            added_count=added_count,
+        )
+
+    async def get_shared_evidence(self) -> list[Evidence]:
+        """Get evidence from the global state.
+
+        Returns:
+            List of Evidence objects from the global state.
+        """
+        state = get_workflow_state()
+        return state.evidence
+
+    async def get_loop_evidence(self, loop_id: str) -> list[Evidence]:
+        """Get evidence collected by a specific loop.
+
+        Args:
+            loop_id: Loop identifier.
+
+        Returns:
+            List of Evidence objects from the loop.
+        """
+        if loop_id not in self._loops:
+            return []
+
+        return self._loops[loop_id].evidence
+
+    async def check_loop_completion(
+        self, loop_id: str, query: str, judge_handler: Any
+    ) -> tuple[bool, str]:
+        """Check if a loop should complete using judge assessment.
+
+        Args:
+            loop_id: Loop identifier.
+            query: Research query.
+            judge_handler: JudgeHandler instance.
+
+        Returns:
+            Tuple of (should_complete: bool, reason: str).
+        """
+        evidence = await self.get_loop_evidence(loop_id)
+
+        if not evidence:
+            return False, "No evidence collected yet"
+
+        try:
+            assessment = await judge_handler.assess(query, evidence)
+            if assessment.sufficient:
+                return True, f"Judge assessment: {assessment.reasoning}"
+            return False, f"Judge assessment: {assessment.reasoning}"
+        except Exception as e:
+            logger.error("Judge assessment failed", error=str(e), loop_id=loop_id)
+            return False, f"Judge assessment failed: {e!s}"
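A hedged sketch of driving the manager above. `loop_body` is a hypothetical stand-in for a real research loop, but the config keys ('loop_id', 'query') and the run_loops_parallel signature match the code in this file:

import asyncio
from typing import Any

from src.middleware.workflow_manager import WorkflowManager

async def demo() -> None:
    manager = WorkflowManager()

    async def loop_body(config: dict[str, Any]) -> str:
        # Hypothetical loop body; a real one would search sources and judge evidence.
        return f"findings for {config['query']}"

    results = await manager.run_loops_parallel(
        loop_configs=[
            {"loop_id": "loop-1", "query": "metformin for long COVID fatigue"},
            {"loop_id": "loop-2", "query": "CoQ10 for long COVID fatigue"},
        ],
        loop_func=loop_body,
    )
    print(results)

asyncio.run(demo())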
src/orchestrator/__init__.py ADDED
@@ -0,0 +1,48 @@
+"""Orchestrator module for research flows and planner agent.
+
+This module provides:
+- PlannerAgent: Creates report plans with sections
+- IterativeResearchFlow: Single research loop pattern
+- DeepResearchFlow: Parallel research loops pattern
+- GraphOrchestrator: Stub for Phase 4 (uses agent chains for now)
+- Protocols: SearchHandlerProtocol, JudgeHandlerProtocol (re-exported from legacy_orchestrator)
+- Orchestrator: Legacy orchestrator class (re-exported from legacy_orchestrator)
+"""
+
+from typing import TYPE_CHECKING
+
+# Re-export protocols and Orchestrator from legacy_orchestrator for backward compatibility
+from src.legacy_orchestrator import (
+    JudgeHandlerProtocol,
+    Orchestrator,
+    SearchHandlerProtocol,
+)
+
+# Lazy imports to avoid circular dependencies
+if TYPE_CHECKING:
+    from src.orchestrator.graph_orchestrator import GraphOrchestrator
+    from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent
+    from src.orchestrator.research_flow import (
+        DeepResearchFlow,
+        IterativeResearchFlow,
+    )
+
+# Public exports
+from src.orchestrator.graph_orchestrator import (
+    GraphOrchestrator,
+    create_graph_orchestrator,
+)
+from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent
+from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
+
+__all__ = [
+    "PlannerAgent",
+    "create_planner_agent",
+    "IterativeResearchFlow",
+    "DeepResearchFlow",
+    "GraphOrchestrator",
+    "create_graph_orchestrator",
+    "SearchHandlerProtocol",
+    "JudgeHandlerProtocol",
+    "Orchestrator",
+]
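With these re-exports, existing callers can keep importing the legacy names from the package root; a sketch, assuming the imports above resolve:

from src.orchestrator import DeepResearchFlow, IterativeResearchFlow, Orchestrator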
src/orchestrator/graph_orchestrator.py ADDED
@@ -0,0 +1,953 @@
+"""Graph orchestrator for Phase 4.
+
+Implements graph-based orchestration using Pydantic AI agents as nodes.
+Supports both iterative and deep research patterns with parallel execution.
+"""
+
+import asyncio
+from collections.abc import AsyncGenerator, Callable
+from typing import TYPE_CHECKING, Any, Literal
+
+import structlog
+
+from src.agent_factory.agents import (
+    create_input_parser_agent,
+    create_knowledge_gap_agent,
+    create_long_writer_agent,
+    create_planner_agent,
+    create_thinking_agent,
+    create_tool_selector_agent,
+    create_writer_agent,
+)
+from src.agent_factory.graph_builder import (
+    AgentNode,
+    DecisionNode,
+    ParallelNode,
+    ResearchGraph,
+    StateNode,
+    create_deep_graph,
+    create_iterative_graph,
+)
+from src.middleware.budget_tracker import BudgetTracker
+from src.middleware.state_machine import WorkflowState, init_workflow_state
+from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
+from src.utils.models import AgentEvent
+
+if TYPE_CHECKING:
+    pass
+
+logger = structlog.get_logger()
+
+
+class GraphExecutionContext:
+    """Context for managing graph execution state."""
+
+    def __init__(self, state: WorkflowState, budget_tracker: BudgetTracker) -> None:
+        """Initialize execution context.
+
+        Args:
+            state: Current workflow state
+            budget_tracker: Budget tracker instance
+        """
+        self.current_node: str = ""
+        self.visited_nodes: set[str] = set()
+        self.node_results: dict[str, Any] = {}
+        self.state = state
+        self.budget_tracker = budget_tracker
+        self.iteration_count = 0
+
+    def set_node_result(self, node_id: str, result: Any) -> None:
+        """Store result from node execution.
+
+        Args:
+            node_id: The node ID
+            result: The execution result
+        """
+        self.node_results[node_id] = result
+
+    def get_node_result(self, node_id: str) -> Any:
+        """Get result from node execution.
+
+        Args:
+            node_id: The node ID
+
+        Returns:
+            The stored result, or None if not found
+        """
+        return self.node_results.get(node_id)
+
+    def has_visited(self, node_id: str) -> bool:
+        """Check if node was visited.
+
+        Args:
+            node_id: The node ID
+
+        Returns:
+            True if visited, False otherwise
+        """
+        return node_id in self.visited_nodes
+
+    def mark_visited(self, node_id: str) -> None:
+        """Mark node as visited.
+
+        Args:
+            node_id: The node ID
+        """
+        self.visited_nodes.add(node_id)
+
+    def update_state(
+        self, updater: Callable[[WorkflowState, Any], WorkflowState], data: Any
+    ) -> None:
+        """Update workflow state.
+
+        Args:
+            updater: Function to update state
+            data: Data to pass to updater
+        """
+        self.state = updater(self.state, data)
+
+
+class GraphOrchestrator:
+    """
+    Graph orchestrator using Pydantic AI Graphs.
+
+    Executes research workflows as graphs with nodes (agents) and edges (transitions).
+    Supports parallel execution, conditional routing, and state management.
+    """
+
+    def __init__(
+        self,
+        mode: Literal["iterative", "deep", "auto"] = "auto",
+        max_iterations: int = 5,
+        max_time_minutes: int = 10,
+        use_graph: bool = True,
+    ) -> None:
+        """
+        Initialize graph orchestrator.
+
+        Args:
+            mode: Research mode ("iterative", "deep", or "auto" to detect)
+            max_iterations: Maximum iterations per loop
+            max_time_minutes: Maximum time per loop
+            use_graph: Whether to use graph execution (True) or agent chains (False)
+        """
+        self.mode = mode
+        self.max_iterations = max_iterations
+        self.max_time_minutes = max_time_minutes
+        self.use_graph = use_graph
+        self.logger = logger
+
+        # Initialize flows (for backward compatibility)
+        self._iterative_flow: IterativeResearchFlow | None = None
+        self._deep_flow: DeepResearchFlow | None = None
+
+        # Graph execution components (lazy initialization)
+        self._graph: ResearchGraph | None = None
+        self._budget_tracker: BudgetTracker | None = None
+
+    async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
+        """
+        Run the research workflow.
+
+        Args:
+            query: The user's research query
+
+        Yields:
+            AgentEvent objects for real-time UI updates
+        """
+        self.logger.info(
+            "Starting graph orchestrator",
+            query=query[:100],
+            mode=self.mode,
+            use_graph=self.use_graph,
+        )
+
+        yield AgentEvent(
+            type="started",
+            message=f"Starting research ({self.mode} mode): {query}",
+            iteration=0,
+        )
+
+        try:
+            # Determine research mode
+            research_mode = self.mode
+            if research_mode == "auto":
+                research_mode = await self._detect_research_mode(query)
+
+            # Use graph execution if enabled, otherwise fall back to agent chains
+            if self.use_graph:
+                async for event in self._run_with_graph(query, research_mode):
+                    yield event
+            else:
+                async for event in self._run_with_chains(query, research_mode):
+                    yield event
+
+        except Exception as e:
+            self.logger.error("Graph orchestrator failed", error=str(e), exc_info=True)
+            yield AgentEvent(
+                type="error",
+                message=f"Research failed: {e!s}",
+                iteration=0,
+            )
+
+    async def _run_with_graph(
+        self, query: str, research_mode: Literal["iterative", "deep"]
+    ) -> AsyncGenerator[AgentEvent, None]:
+        """Run workflow using graph execution.
+
+        Args:
+            query: The research query
+            research_mode: The research mode
+
+        Yields:
+            AgentEvent objects
+        """
+        # Initialize state and budget tracker
+        from src.services.embeddings import get_embedding_service
+
+        embedding_service = get_embedding_service()
+        state = init_workflow_state(embedding_service=embedding_service)
+        budget_tracker = BudgetTracker()
+        budget_tracker.create_budget(
+            loop_id="graph_execution",
+            tokens_limit=100000,
+            time_limit_seconds=self.max_time_minutes * 60,
+            iterations_limit=self.max_iterations,
+        )
+        budget_tracker.start_timer("graph_execution")
+
+        context = GraphExecutionContext(state, budget_tracker)
+
+        # Build graph
+        self._graph = await self._build_graph(research_mode)
+
+        # Execute graph
+        async for event in self._execute_graph(query, context):
+            yield event
+
+    async def _run_with_chains(
+        self, query: str, research_mode: Literal["iterative", "deep"]
+    ) -> AsyncGenerator[AgentEvent, None]:
+        """Run workflow using agent chains (backward compatibility).
+
+        Args:
+            query: The research query
+            research_mode: The research mode
+
+        Yields:
+            AgentEvent objects
+        """
+        if research_mode == "iterative":
+            yield AgentEvent(
+                type="searching",
+                message="Running iterative research flow...",
+                iteration=1,
+            )
+
+            if self._iterative_flow is None:
+                self._iterative_flow = IterativeResearchFlow(
+                    max_iterations=self.max_iterations,
+                    max_time_minutes=self.max_time_minutes,
+                )
+
+            final_report = await self._iterative_flow.run(query)
+
+            yield AgentEvent(
+                type="complete",
+                message=final_report,
+                data={"mode": "iterative"},
+                iteration=1,
+            )
+
+        elif research_mode == "deep":
+            yield AgentEvent(
+                type="searching",
+                message="Running deep research flow...",
+                iteration=1,
+            )
+
+            if self._deep_flow is None:
+                self._deep_flow = DeepResearchFlow(
+                    max_iterations=self.max_iterations,
+                    max_time_minutes=self.max_time_minutes,
+                )
+
+            final_report = await self._deep_flow.run(query)
+
+            yield AgentEvent(
+                type="complete",
+                message=final_report,
+                data={"mode": "deep"},
+                iteration=1,
+            )
+
+    async def _build_graph(self, mode: Literal["iterative", "deep"]) -> ResearchGraph:
+        """Build graph for the specified mode.
+
+        Args:
+            mode: Research mode
+
+        Returns:
+            Constructed ResearchGraph
+        """
+        if mode == "iterative":
+            # Get agents
+            knowledge_gap_agent = create_knowledge_gap_agent()
+            tool_selector_agent = create_tool_selector_agent()
+            thinking_agent = create_thinking_agent()
+            writer_agent = create_writer_agent()
+
+            # Create graph
+            graph = create_iterative_graph(
+                knowledge_gap_agent=knowledge_gap_agent.agent,
+                tool_selector_agent=tool_selector_agent.agent,
+                thinking_agent=thinking_agent.agent,
+                writer_agent=writer_agent.agent,
+            )
+        else:  # deep
+            # Get agents
+            planner_agent = create_planner_agent()
+            knowledge_gap_agent = create_knowledge_gap_agent()
+            tool_selector_agent = create_tool_selector_agent()
+            thinking_agent = create_thinking_agent()
+            writer_agent = create_writer_agent()
+            long_writer_agent = create_long_writer_agent()
+
+            # Create graph
+            graph = create_deep_graph(
+                planner_agent=planner_agent.agent,
+                knowledge_gap_agent=knowledge_gap_agent.agent,
+                tool_selector_agent=tool_selector_agent.agent,
+                thinking_agent=thinking_agent.agent,
+                writer_agent=writer_agent.agent,
+                long_writer_agent=long_writer_agent.agent,
+            )
+
+        return graph
+
+    def _emit_start_event(
+        self, node: Any, current_node_id: str, iteration: int, context: GraphExecutionContext
+    ) -> AgentEvent:
+        """Emit start event for a node.
+
+        Args:
+            node: The node being executed
+            current_node_id: Current node ID
+            iteration: Current iteration number
+            context: Execution context
+
+        Returns:
+            AgentEvent for the start of node execution
+        """
+        if node and node.node_id == "planner":
+            return AgentEvent(
+                type="searching",
+                message="Creating report plan...",
+                iteration=iteration,
+            )
+        elif node and node.node_id == "parallel_loops":
+            # Get report plan to show section count
+            report_plan = context.get_node_result("planner")
+            if report_plan and hasattr(report_plan, "report_outline"):
+                section_count = len(report_plan.report_outline)
+                return AgentEvent(
+                    type="looping",
+                    message=f"Running parallel research loops for {section_count} sections...",
+                    iteration=iteration,
+                    data={"sections": section_count},
+                )
+            return AgentEvent(
+                type="looping",
+                message="Running parallel research loops...",
+                iteration=iteration,
+            )
+        elif node and node.node_id == "synthesizer":
+            return AgentEvent(
+                type="synthesizing",
+                message="Synthesizing final report from section drafts...",
+                iteration=iteration,
+            )
+        return AgentEvent(
+            type="looping",
+            message=f"Executing node: {current_node_id}",
+            iteration=iteration,
+        )
+
+    def _emit_completion_event(
+        self, node: Any, current_node_id: str, result: Any, iteration: int
+    ) -> AgentEvent:
+        """Emit completion event for a node.
+
+        Args:
+            node: The node that was executed
+            current_node_id: Current node ID
+            result: Node execution result
+            iteration: Current iteration number
+
+        Returns:
+            AgentEvent for the completion of node execution
+        """
+        if not node:
+            return AgentEvent(
+                type="looping",
+                message=f"Completed node: {current_node_id}",
+                iteration=iteration,
+            )
+
+        if node.node_id == "planner":
+            if isinstance(result, dict) and "report_outline" in result:
+                section_count = len(result["report_outline"])
+                return AgentEvent(
+                    type="search_complete",
+                    message=f"Report plan created with {section_count} sections",
+                    iteration=iteration,
+                    data={"sections": section_count},
+                )
+            return AgentEvent(
+                type="search_complete",
+                message="Report plan created",
+                iteration=iteration,
+            )
+        elif node.node_id == "parallel_loops":
+            if isinstance(result, list):
+                return AgentEvent(
+                    type="search_complete",
+                    message=f"Completed parallel research for {len(result)} sections",
+                    iteration=iteration,
+                    data={"sections_completed": len(result)},
+                )
+            return AgentEvent(
+                type="search_complete",
+                message="Parallel research loops completed",
+                iteration=iteration,
+            )
+        elif node.node_id == "synthesizer":
+            return AgentEvent(
+                type="synthesizing",
+                message="Final report synthesis completed",
+                iteration=iteration,
+            )
+        return AgentEvent(
+            type="searching" if node.node_type == "agent" else "looping",
+            message=f"Completed {node.node_type} node: {current_node_id}",
+            iteration=iteration,
+        )
+
+    async def _execute_graph(
+        self, query: str, context: GraphExecutionContext
+    ) -> AsyncGenerator[AgentEvent, None]:
+        """Execute the graph from entry node.
+
+        Args:
+            query: The research query
+            context: Execution context
+
+        Yields:
+            AgentEvent objects
+        """
+        if not self._graph:
+            raise ValueError("Graph not built")
+
+        current_node_id = self._graph.entry_node
+        iteration = 0
+
+        while current_node_id and current_node_id not in self._graph.exit_nodes:
+            # Check budget
+            if not context.budget_tracker.can_continue("graph_execution"):
+                self.logger.warning("Budget exceeded, exiting graph execution")
+                break
+
+            # Execute current node
+            iteration += 1
+            context.current_node = current_node_id
+            node = self._graph.get_node(current_node_id)
+
+            # Emit start event
+            yield self._emit_start_event(node, current_node_id, iteration, context)
+
+            try:
+                result = await self._execute_node(current_node_id, query, context)
+                context.set_node_result(current_node_id, result)
+                context.mark_visited(current_node_id)
+
+                # Yield completion event
+                yield self._emit_completion_event(node, current_node_id, result, iteration)
+
+            except Exception as e:
+                self.logger.error("Node execution failed", node_id=current_node_id, error=str(e))
+                yield AgentEvent(
+                    type="error",
+                    message=f"Node {current_node_id} failed: {e!s}",
+                    iteration=iteration,
+                )
+                break
+
+            # Get next node(s)
+            next_nodes = self._get_next_node(current_node_id, context)
+
+            if not next_nodes:
+                # No more nodes, check if we're at exit
+                if current_node_id in self._graph.exit_nodes:
+                    break
+                # Otherwise, we've reached a dead end
+                self.logger.warning("Reached dead end in graph", node_id=current_node_id)
+                break
+
+            current_node_id = next_nodes[0]  # For now, take first next node (handle parallel later)
+
+        # Final event
+        final_result = context.get_node_result(current_node_id) if current_node_id else None
+        yield AgentEvent(
+            type="complete",
+            message=final_result if isinstance(final_result, str) else "Research completed",
+            data={"mode": self.mode, "iterations": iteration},
+            iteration=iteration,
+        )
+
+    async def _execute_node(self, node_id: str, query: str, context: GraphExecutionContext) -> Any:
+        """Execute a single node.
+
+        Args:
+            node_id: The node ID
+            query: The research query
+            context: Execution context
+
+        Returns:
+            Node execution result
+        """
+        if not self._graph:
+            raise ValueError("Graph not built")
+
+        node = self._graph.get_node(node_id)
+        if not node:
+            raise ValueError(f"Node {node_id} not found")
+
+        if isinstance(node, AgentNode):
+            return await self._execute_agent_node(node, query, context)
+        elif isinstance(node, StateNode):
+            return await self._execute_state_node(node, query, context)
+        elif isinstance(node, DecisionNode):
+            return await self._execute_decision_node(node, query, context)
+        elif isinstance(node, ParallelNode):
+            return await self._execute_parallel_node(node, query, context)
+        else:
+            raise ValueError(f"Unknown node type: {type(node)}")
+
+    async def _execute_agent_node(
+        self, node: AgentNode, query: str, context: GraphExecutionContext
+    ) -> Any:
+        """Execute an agent node.
+
+        Special handling for deep research nodes:
+        - "planner": Takes query string, returns ReportPlan
+        - "synthesizer": Takes query + ReportPlan + section drafts, returns final report
+
+        Args:
+            node: The agent node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            Agent execution result
+        """
+        # Special handling for synthesizer node
+        if node.node_id == "synthesizer":
+            # Call LongWriterAgent.write_report() directly instead of using agent.run()
+            from src.agent_factory.agents import create_long_writer_agent
+            from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan
+
+            report_plan = context.get_node_result("planner")
+            section_drafts = context.get_node_result("parallel_loops") or []
+
+            if not isinstance(report_plan, ReportPlan):
+                raise ValueError("ReportPlan not found for synthesizer")
+
+            if not section_drafts:
+                raise ValueError("Section drafts not found for synthesizer")
+
+            # Create ReportDraft from section drafts
+            report_draft = ReportDraft(
+                sections=[
+                    ReportDraftSection(
+                        section_title=section.title,
+                        section_content=draft,
+                    )
+                    for section, draft in zip(
+                        report_plan.report_outline, section_drafts, strict=False
+                    )
+                ]
+            )
+
+            # Get LongWriterAgent instance and call write_report directly
+            long_writer_agent = create_long_writer_agent()
+            final_report = await long_writer_agent.write_report(
+                original_query=query,
+                report_title=report_plan.report_title,
+                report_draft=report_draft,
+            )
+
+            # Estimate tokens (rough estimate)
+            estimated_tokens = len(final_report) // 4  # Rough token estimate
+            context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
+
+            return final_report
+
+        # Standard agent execution
+        # Prepare input based on node type
+        if node.node_id == "planner":
+            # Planner takes the original query
+            input_data = query
+        else:
+            # Standard: use previous node result or query
+            prev_result = context.get_node_result(context.current_node)
+            input_data = prev_result if prev_result is not None else query
+
+        # Apply input transformer if provided
+        if node.input_transformer:
+            input_data = node.input_transformer(input_data)
+
+        # Execute agent
+        result = await node.agent.run(input_data)
+
+        # Transform output if needed
+        output = result.output
+        if node.output_transformer:
+            output = node.output_transformer(output)
+
+        # Estimate and track tokens
+        if hasattr(result, "usage") and result.usage:
+            tokens = result.usage.total_tokens if hasattr(result.usage, "total_tokens") else 0
+            context.budget_tracker.add_tokens("graph_execution", tokens)
+
+        return output
+
+    async def _execute_state_node(
+        self, node: StateNode, query: str, context: GraphExecutionContext
+    ) -> Any:
+        """Execute a state node.
+
+        Special handling for deep research state nodes:
+        - "store_plan": Stores ReportPlan in context for parallel loops
+        - "collect_drafts": Stores section drafts in context for synthesizer
+
+        Args:
+            node: The state node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            State update result
+        """
+        # Get previous result for state update
+        # For "store_plan", get from planner node
+        # For "collect_drafts", get from parallel_loops node
+        if node.node_id == "store_plan":
+            prev_result = context.get_node_result("planner")
+        elif node.node_id == "collect_drafts":
+            prev_result = context.get_node_result("parallel_loops")
+        else:
+            prev_result = context.get_node_result(context.current_node)
+
+        # Update state
+        updated_state = node.state_updater(context.state, prev_result)
+        context.state = updated_state
+
+        # Store result in context for next nodes to access
+        context.set_node_result(node.node_id, prev_result)
+
+        # Read state if needed
+        if node.state_reader:
+            return node.state_reader(context.state)
+
+        return prev_result  # Return the stored result for next nodes
+
+    async def _execute_decision_node(
+        self, node: DecisionNode, query: str, context: GraphExecutionContext
+    ) -> str:
+        """Execute a decision node.
+
+        Args:
+            node: The decision node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            Next node ID
+        """
+        # Get previous result for decision
+        prev_result = context.get_node_result(context.current_node)
+
+        # Make decision
+        next_node_id = node.decision_function(prev_result)
+
+        # Validate decision
+        if next_node_id not in node.options:
+            self.logger.warning(
+                "Decision function returned invalid node",
+                node_id=node.node_id,
+                returned=next_node_id,
+                options=node.options,
+            )
+            # Default to first option
+            next_node_id = node.options[0]
+
+        return next_node_id
+
+    async def _execute_parallel_node(
+        self, node: ParallelNode, query: str, context: GraphExecutionContext
+    ) -> list[Any]:
+        """Execute a parallel node.
+
+        Special handling for deep research "parallel_loops" node:
+        - Extracts report plan from previous node result
+        - Creates IterativeResearchFlow instances for each section
+        - Executes them in parallel
+        - Returns section drafts
+
+        Args:
+            node: The parallel node
+            query: The research query
+            context: Execution context
+
+        Returns:
+            List of results from parallel nodes
+        """
+        # Special handling for deep research parallel_loops node
+        if node.node_id == "parallel_loops":
+            return await self._execute_deep_research_parallel_loops(node, query, context)
+
+        # Standard parallel node execution
+        # Execute all parallel nodes concurrently
+        tasks = [
+            self._execute_node(parallel_node_id, query, context)
+            for parallel_node_id in node.parallel_nodes
+        ]
+
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Handle exceptions
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                self.logger.error(
+                    "Parallel node execution failed",
+                    node_id=node.parallel_nodes[i] if i < len(node.parallel_nodes) else "unknown",
+                    error=str(result),
+                )
+                results[i] = None
+
+        # Aggregate if needed
+        if node.aggregator:
+            aggregated = node.aggregator(results)
+            # Type cast: aggregator returns Any, but we expect list[Any]
+            return list(aggregated) if isinstance(aggregated, list) else [aggregated]
+
+        return results
+
+    async def _execute_deep_research_parallel_loops(
+        self, node: ParallelNode, query: str, context: GraphExecutionContext
+    ) -> list[str]:
+        """Execute parallel iterative research loops for deep research.
+
+        Args:
+            node: The parallel node (should be "parallel_loops")
+            query: The research query
+            context: Execution context
+
+        Returns:
+            List of section draft strings
+        """
+        from src.agent_factory.judges import create_judge_handler
+        from src.orchestrator.research_flow import IterativeResearchFlow
+        from src.utils.models import ReportPlan
+
+        # Get report plan from previous node (store_plan)
+        # The plan should be stored in context.node_results from the planner node
+        planner_result = context.get_node_result("planner")
+        if not isinstance(planner_result, ReportPlan):
+            self.logger.error(
+                "Planner result is not a ReportPlan",
+                type=type(planner_result),
+            )
+            raise ValueError("Planner must return ReportPlan for deep research")
+
+        report_plan: ReportPlan = planner_result
+        self.logger.info(
+            "Executing parallel loops for deep research",
+            sections=len(report_plan.report_outline),
+        )
+
+        # Create judge handler for iterative flows
+        judge_handler = create_judge_handler()
+
+        # Create and execute iterative research flows for each section
+        async def run_section_research(section_index: int) -> str:
+            """Run iterative research for a single section."""
+            section = report_plan.report_outline[section_index]
+
+            try:
+                # Create iterative research flow
+                flow = IterativeResearchFlow(
+                    max_iterations=self.max_iterations,
+                    max_time_minutes=self.max_time_minutes,
+                    verbose=False,  # Less verbose in parallel execution
+                    use_graph=False,  # Use agent chains for section research
+                    judge_handler=judge_handler,
+                )
+
+                # Run research for this section
+                section_draft = await flow.run(
+                    query=section.key_question,
+                    background_context=report_plan.background_context,
+                )
+
+                self.logger.info(
+                    "Section research completed",
+                    section_index=section_index,
+                    section_title=section.title,
+                    draft_length=len(section_draft),
+                )
+
+                return section_draft
+
+            except Exception as e:
+                self.logger.error(
+                    "Section research failed",
+                    section_index=section_index,
+                    section_title=section.title,
+                    error=str(e),
+                )
+                # Return empty string for failed sections
+                return f"# {section.title}\n\n[Research failed: {e!s}]"
+
+        # Execute all sections in parallel
+        section_drafts = await asyncio.gather(
+            *(run_section_research(i) for i in range(len(report_plan.report_outline))),
+            return_exceptions=True,
+        )
+
+        # Handle exceptions and filter None results
+        filtered_drafts: list[str] = []
+        for i, draft in enumerate(section_drafts):
+            if isinstance(draft, Exception):
+                self.logger.error(
+                    "Section research exception",
+                    section_index=i,
+                    error=str(draft),
+                )
+                filtered_drafts.append(
+                    f"# {report_plan.report_outline[i].title}\n\n[Research failed: {draft!s}]"
+                )
+            elif draft is not None:
+                # Type narrowing: after Exception check, draft is str | None
+                assert isinstance(draft, str), "Expected str after Exception check"
+                filtered_drafts.append(draft)
+
+        self.logger.info(
+            "Parallel loops completed",
+            sections=len(filtered_drafts),
+            total_sections=len(report_plan.report_outline),
+        )
+
+        return filtered_drafts
+
+    def _get_next_node(self, node_id: str, context: GraphExecutionContext) -> list[str]:
+        """Get next node(s) from current node.
+
+        Args:
+            node_id: Current node ID
+            context: Execution context
+
+        Returns:
+            List of next node IDs
+        """
+        if not self._graph:
+            return []
+
+        # Get node result for condition evaluation
+        node_result = context.get_node_result(node_id)
+
+        # Get next nodes
+        next_nodes = self._graph.get_next_nodes(node_id, context=node_result)
+
+        # If this was a decision node, use its result
+        node = self._graph.get_node(node_id)
+        if isinstance(node, DecisionNode):
+            decision_result = node_result
+            if isinstance(decision_result, str):
+                return [decision_result]
+
+        # Return next node IDs
+        return [next_node_id for next_node_id, _ in next_nodes]
+
+    async def _detect_research_mode(self, query: str) -> Literal["iterative", "deep"]:
+        """
+        Detect research mode from query using input parser agent.
+
+        Uses input parser agent to analyze query and determine research mode.
+        Falls back to heuristic if parser fails.
+
+        Args:
+            query: The research query
+
+        Returns:
+            Detected research mode
+        """
+        try:
+            # Use input parser agent for intelligent mode detection
+            input_parser = create_input_parser_agent()
+            parsed_query = await input_parser.parse(query)
+            self.logger.info(
+                "Research mode detected by input parser",
+                mode=parsed_query.research_mode,
+                query=query[:100],
+            )
+            return parsed_query.research_mode
+        except Exception as e:
+            # Fallback to heuristic if parser fails
+            self.logger.warning(
+                "Input parser failed, using heuristic",
+                error=str(e),
+                query=query[:100],
+            )
+            query_lower = query.lower()
+            if any(
+                keyword in query_lower
+                for keyword in [
+                    "section",
+                    "sections",
+                    "report",
+                    "outline",
+                    "structure",
+                    "comprehensive",
+                    "analyze",
+                    "analysis",
+                ]
+            ):
+                return "deep"
+            return "iterative"
+
+
+def create_graph_orchestrator(
+    mode: Literal["iterative", "deep", "auto"] = "auto",
+    max_iterations: int = 5,
+    max_time_minutes: int = 10,
+    use_graph: bool = True,
+) -> GraphOrchestrator:
+    """
+    Factory function to create a graph orchestrator.
+
+    Args:
+        mode: Research mode
+        max_iterations: Maximum iterations per loop
+        max_time_minutes: Maximum time per loop
+        use_graph: Whether to use graph execution (True) or agent chains (False)
+
+    Returns:
+        Configured GraphOrchestrator instance
+    """
+    return GraphOrchestrator(
+        mode=mode,
+        max_iterations=max_iterations,
+        max_time_minutes=max_time_minutes,
+        use_graph=use_graph,
+    )
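A sketch of consuming the orchestrator's event stream (assumes API keys and agent config are set up; the event fields match the AgentEvent usage above):

import asyncio

from src.orchestrator.graph_orchestrator import create_graph_orchestrator

async def demo() -> None:
    orchestrator = create_graph_orchestrator(mode="auto", use_graph=True)
    async for event in orchestrator.run("What existing drugs might help treat long COVID fatigue?"):
        # Stream progress to a UI or log; "complete" carries the final report.
        print(event.type, "-", event.message[:80])

asyncio.run(demo())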
src/orchestrator/planner_agent.py ADDED
@@ -0,0 +1,174 @@
+"""Planner agent for creating report plans with sections and background context.
+
+Converts the folder/planner_agent.py implementation to use Pydantic AI.
+"""
+
+from datetime import datetime
+from typing import Any
+
+import structlog
+from pydantic_ai import Agent
+
+from src.agent_factory.judges import get_model
+from src.tools.crawl_adapter import crawl_website
+from src.tools.web_search_adapter import web_search
+from src.utils.exceptions import ConfigurationError, JudgeError
+from src.utils.models import ReportPlan, ReportPlanSection
+
+logger = structlog.get_logger()
+
+
+# System prompt for the planner agent
+SYSTEM_PROMPT = f"""
+You are a research manager, managing a team of research agents. Today's date is {datetime.now().strftime("%Y-%m-%d")}.
+Given a research query, your job is to produce an initial outline of the report (section titles and key questions),
+as well as some background context. Each section will be assigned to a different researcher in your team who will then
+carry out research on the section.
+
+You will be given:
+- An initial research query
+
+Your task is to:
+1. Produce 1-2 paragraphs of initial background context (if needed) on the query by running web searches or crawling websites
+2. Produce an outline of the report that includes a list of section titles and the key question to be addressed in each section
+3. Provide a title for the report that will be used as the main heading
+
+Guidelines:
+- Each section should cover a single topic/question that is independent of other sections
+- The key question for each section should include both the NAME and DOMAIN NAME / WEBSITE (if available and applicable) if it is related to a company, product or similar
+- The background_context should not be more than 2 paragraphs
+- The background_context should be very specific to the query and include any information that is relevant for researchers across all sections of the report
+- The background_context should be drawn only from web search or crawl results rather than prior knowledge (i.e. it should only be included if you have called tools)
+- For example, if the query is about a company, the background context should include some basic information about what the company does
+- DO NOT do more than 2 tool calls
+
+Only output JSON. Follow the JSON schema for ReportPlan. Do not output anything else.
+"""
+
+
+class PlannerAgent:
+    """
+    Planner agent that creates report plans with sections and background context.
+
+    Uses Pydantic AI to generate structured ReportPlan output with optional
+    web search and crawl tool usage for background context.
+    """
+
+    def __init__(
+        self,
+        model: Any | None = None,
+        web_search_tool: Any | None = None,
+        crawl_tool: Any | None = None,
+    ) -> None:
+        """
+        Initialize the planner agent.
+
+        Args:
+            model: Optional Pydantic AI model. If None, uses config default.
+            web_search_tool: Optional web search tool function. If None, uses default.
+            crawl_tool: Optional crawl tool function. If None, uses default.
+        """
+        self.model = model or get_model()
+        self.web_search_tool = web_search_tool or web_search
+        self.crawl_tool = crawl_tool or crawl_website
+        self.logger = logger
+
+        # Validate tools are callable
+        if not callable(self.web_search_tool):
+            raise ConfigurationError("web_search_tool must be callable")
+        if not callable(self.crawl_tool):
+            raise ConfigurationError("crawl_tool must be callable")
+
+        # Initialize Pydantic AI Agent
+        self.agent = Agent(
+            model=self.model,
+            output_type=ReportPlan,
+            system_prompt=SYSTEM_PROMPT,
+            tools=[self.web_search_tool, self.crawl_tool],
+            retries=3,
+        )
+
+    async def run(self, query: str) -> ReportPlan:
+        """
+        Run the planner agent to generate a report plan.
+
+        Args:
+            query: The user's research query
+
+        Returns:
+            ReportPlan with sections, background context, and report title
+
+        Raises:
+            JudgeError: If planning fails after retries
+            ConfigurationError: If agent configuration is invalid
+        """
+        self.logger.info("Starting report planning", query=query[:100])
+
+        user_message = f"QUERY: {query}"
+
+        try:
+            # Run the agent
+            result = await self.agent.run(user_message)
+            report_plan = result.output
+
+            # Validate report plan
+            if not report_plan.report_outline:
+                self.logger.warning("Report plan has no sections", query=query[:100])
+                raise JudgeError("Report plan must have at least one section")
+
+            if not report_plan.report_title:
+                self.logger.warning("Report plan has no title", query=query[:100])
+                raise JudgeError("Report plan must have a title")
+
+            self.logger.info(
+                "Report plan created",
+                sections=len(report_plan.report_outline),
+                has_background=bool(report_plan.background_context),
+            )
+
+            return report_plan
+
+        except Exception as e:
+            self.logger.error("Planning failed", error=str(e), query=query[:100])
+
+            # Fallback: return minimal report plan
+            if isinstance(e, JudgeError | ConfigurationError):
+                raise
+
+            # For other errors, return a minimal plan
+            return ReportPlan(
+                background_context="",
+                report_outline=[
+                    ReportPlanSection(
+                        title="Research Findings",
+                        key_question=query,
+                    )
+                ],
+                report_title=f"Research Report: {query[:50]}",
+            )
+
+
+def create_planner_agent(model: Any | None = None) -> PlannerAgent:
+    """
+    Factory function to create a planner agent.
+
+    Args:
+        model: Optional Pydantic AI model. If None, uses settings default.
+
+    Returns:
+        Configured PlannerAgent instance
+
+    Raises:
+        ConfigurationError: If required API keys are missing
+    """
+    try:
+        # Get model from settings if not provided
+        if model is None:
+            model = get_model()
+
+        # Create and return planner agent
+        return PlannerAgent(model=model)
+
+    except Exception as e:
+        logger.error("Failed to create planner agent", error=str(e))
+        raise ConfigurationError(f"Failed to create planner agent: {e}") from e
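A sketch of running the planner on its own (assumes the default model and the search/crawl tools are configured; field names follow the ReportPlan model used above):

import asyncio

from src.orchestrator.planner_agent import create_planner_agent

async def demo() -> None:
    planner = create_planner_agent()
    plan = await planner.run("Drug repurposing candidates for long COVID fatigue")
    print(plan.report_title)
    for section in plan.report_outline:
        print("-", section.title, "->", section.key_question)

asyncio.run(demo())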
src/orchestrator/research_flow.py ADDED
@@ -0,0 +1,999 @@
+"""Research flow implementations for iterative and deep research patterns.
+
+Converts the folder/iterative_research.py and folder/deep_research.py
+implementations to use Pydantic AI agents.
+"""
+
+import asyncio
+import time
+from typing import Any
+
+import structlog
+
+from src.agent_factory.agents import (
+    create_graph_orchestrator,
+    create_knowledge_gap_agent,
+    create_long_writer_agent,
+    create_planner_agent,
+    create_proofreader_agent,
+    create_thinking_agent,
+    create_tool_selector_agent,
+    create_writer_agent,
+)
+from src.agent_factory.judges import create_judge_handler
+from src.middleware.budget_tracker import BudgetTracker
+from src.middleware.state_machine import get_workflow_state, init_workflow_state
+from src.middleware.workflow_manager import WorkflowManager
+from src.services.llamaindex_rag import LlamaIndexRAGService, get_rag_service
+from src.tools.tool_executor import execute_tool_tasks
+from src.utils.exceptions import ConfigurationError
+from src.utils.models import (
+    AgentSelectionPlan,
+    AgentTask,
+    Citation,
+    Conversation,
+    Evidence,
+    JudgeAssessment,
+    KnowledgeGapOutput,
+    ReportDraft,
+    ReportDraftSection,
+    ReportPlan,
+    SourceName,
+    ToolAgentOutput,
+)
+
+logger = structlog.get_logger()
+
+
+class IterativeResearchFlow:
+    """
+    Iterative research flow that runs a single research loop.
+
+    Pattern: Generate observations → Evaluate gaps → Select tools → Execute → Repeat
+    until research is complete or constraints are met.
+    """
+
+    def __init__(
+        self,
+        max_iterations: int = 5,
+        max_time_minutes: int = 10,
+        verbose: bool = True,
+        use_graph: bool = False,
+        judge_handler: Any | None = None,
+    ) -> None:
+        """
+        Initialize iterative research flow.
+
+        Args:
+            max_iterations: Maximum number of iterations
+            max_time_minutes: Maximum time in minutes
+            verbose: Whether to log progress
+            use_graph: Whether to use graph-based execution (True) or agent chains (False)
+            judge_handler: Optional judge handler; if None, a new one is created
+        """
+        self.max_iterations = max_iterations
+        self.max_time_minutes = max_time_minutes
+        self.verbose = verbose
+        self.use_graph = use_graph
+        self.logger = logger
+
+        # Initialize agents (only needed for agent chain execution)
+        if not use_graph:
+            self.knowledge_gap_agent = create_knowledge_gap_agent()
+            self.tool_selector_agent = create_tool_selector_agent()
+            self.thinking_agent = create_thinking_agent()
+            self.writer_agent = create_writer_agent()
+            # Initialize judge handler (use provided or create new)
+            self.judge_handler = judge_handler or create_judge_handler()
+
+        # Initialize state (only needed for agent chain execution)
+        if not use_graph:
+            self.conversation = Conversation()
+            self.iteration = 0
+            self.start_time: float | None = None
+            self.should_continue = True
+
+        # Initialize budget tracker
+        self.budget_tracker = BudgetTracker()
+        self.loop_id = "iterative_flow"
+        self.budget_tracker.create_budget(
+            loop_id=self.loop_id,
+            tokens_limit=100000,
+            time_limit_seconds=max_time_minutes * 60,
+            iterations_limit=max_iterations,
+        )
+        self.budget_tracker.start_timer(self.loop_id)
+
+        # Initialize RAG service (lazy, may be None if unavailable)
+        self._rag_service: LlamaIndexRAGService | None = None
+
+        # Graph orchestrator (lazy initialization)
+        self._graph_orchestrator: Any = None
+
+    async def run(
+        self,
+        query: str,
+        background_context: str = "",
+        output_length: str = "",
+        output_instructions: str = "",
+    ) -> str:
+        """
+        Run the iterative research flow.
+
+        Args:
+            query: The research query
+            background_context: Optional background context
+            output_length: Optional description of desired output length
+            output_instructions: Optional additional instructions
+
+        Returns:
+            Final report string
+        """
+        if self.use_graph:
+            return await self._run_with_graph(
+                query, background_context, output_length, output_instructions
+            )
+        else:
+            return await self._run_with_chains(
+                query, background_context, output_length, output_instructions
+            )
+
+    async def _run_with_chains(
+        self,
+        query: str,
+        background_context: str = "",
+        output_length: str = "",
+        output_instructions: str = "",
+    ) -> str:
+        """
+        Run the iterative research flow using agent chains.
+
+        Args:
+            query: The research query
+            background_context: Optional background context
+            output_length: Optional description of desired output length
+            output_instructions: Optional additional instructions
+
+        Returns:
+            Final report string
+        """
+        self.start_time = time.time()
+        self.logger.info("Starting iterative research (agent chains)", query=query[:100])
+
+        # Initialize conversation with first iteration
+        self.conversation.add_iteration()
+
+        # Main research loop
+        while self.should_continue and self._check_constraints():
+            self.iteration += 1
+            self.logger.info("Starting iteration", iteration=self.iteration)
+
+            # Add new iteration to conversation
+            self.conversation.add_iteration()
+
+            # 1. Generate observations
+            await self._generate_observations(query, background_context)
+
+            # 2. Evaluate gaps
+            evaluation = await self._evaluate_gaps(query, background_context)
+
+            # 3. Assess with judge (after tools execute, we'll assess again)
+            # For now, check knowledge gap evaluation
+            # After tool execution, we'll do a full judge assessment
+
+            # Check if research is complete (knowledge gap agent says complete)
+            if evaluation.research_complete:
+                self.should_continue = False
+                self.logger.info("Research marked as complete by knowledge gap agent")
+                break
+
+            # 4. Select tools for next gap
+            next_gap = evaluation.outstanding_gaps[0] if evaluation.outstanding_gaps else query
+            selection_plan = await self._select_agents(next_gap, query, background_context)
+
+            # 5. Execute tools
+            await self._execute_tools(selection_plan.tasks)
+
+            # 6. Assess evidence sufficiency with judge
+            judge_assessment = await self._assess_with_judge(query)
+
+            # Check if judge says evidence is sufficient
+            if judge_assessment.sufficient:
+                self.should_continue = False
+                self.logger.info(
+                    "Research marked as complete by judge",
+                    confidence=judge_assessment.confidence,
+                    reasoning=judge_assessment.reasoning[:100],
+                )
+                break
+
+            # Update budget tracker
+            self.budget_tracker.increment_iteration(self.loop_id)
+            self.budget_tracker.update_timer(self.loop_id)
+
+        # Create final report
+        report = await self._create_final_report(query, output_length, output_instructions)
+
+        elapsed = time.time() - (self.start_time or time.time())
+        self.logger.info(
+            "Iterative research completed",
+            iterations=self.iteration,
+            elapsed_minutes=elapsed / 60,
+        )
+
+        return report
+
+    async def _run_with_graph(
+        self,
+        query: str,
+        background_context: str = "",
+        output_length: str = "",
+        output_instructions: str = "",
+    ) -> str:
+        """
+        Run the iterative research flow using graph execution.
+
+        Args:
+            query: The research query
+            background_context: Optional background context (currently ignored in graph execution)
+            output_length: Optional description of desired output length (currently ignored in graph execution)
+            output_instructions: Optional additional instructions (currently ignored in graph execution)
+
+        Returns:
+            Final report string
+        """
+        self.logger.info("Starting iterative research (graph execution)", query=query[:100])
+
+        # Create graph orchestrator (lazy initialization)
+        if self._graph_orchestrator is None:
+            self._graph_orchestrator = create_graph_orchestrator(
+                mode="iterative",
+                max_iterations=self.max_iterations,
+                max_time_minutes=self.max_time_minutes,
+                use_graph=True,
+            )
+
+        # Run orchestrator and collect events
+        final_report = ""
+        async for event in self._graph_orchestrator.run(query):
+            if event.type == "complete":
+                final_report = event.message
+                break
+            elif event.type == "error":
+                self.logger.error("Graph execution error", error=event.message)
+                raise RuntimeError(f"Graph execution failed: {event.message}")
+
+        if not final_report:
+            self.logger.warning("No complete event received from graph orchestrator")
+            final_report = "Research completed but no report was generated."
+
+        self.logger.info("Iterative research completed (graph execution)")
+
+        return final_report
+
+    def _check_constraints(self) -> bool:
+        """Check if we've exceeded constraints."""
+        if self.iteration >= self.max_iterations:
+            self.logger.info("Max iterations reached", max=self.max_iterations)
+            return False
+
+        if self.start_time:
+            elapsed_minutes = (time.time() - self.start_time) / 60
+            if elapsed_minutes >= self.max_time_minutes:
+                self.logger.info("Max time reached", max=self.max_time_minutes)
+                return False
+
+        # Check budget tracker
+        self.budget_tracker.update_timer(self.loop_id)
287
+ exceeded, reason = self.budget_tracker.check_budget(self.loop_id)
288
+ if exceeded:
289
+ self.logger.info("Budget exceeded", reason=reason)
290
+ return False
291
+
292
+ return True
293
+
294
+ async def _generate_observations(self, query: str, background_context: str = "") -> str:
295
+ """Generate observations from current research state."""
296
+ # Build input prompt for token estimation
297
+ conversation_history = self.conversation.compile_conversation_history()
298
+ # Build background context section separately to avoid backslash in f-string
299
+ background_section = (
300
+ f"BACKGROUND CONTEXT:\n{background_context}\n\n" if background_context else ""
301
+ )
302
+ input_prompt = f"""
303
+ You are starting iteration {self.iteration} of your research process.
304
+
305
+ ORIGINAL QUERY:
306
+ {query}
307
+
308
+ {background_section}HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
309
+ {conversation_history or "No previous actions, findings or thoughts available."}
310
+ """
311
+
312
+ observations = await self.thinking_agent.generate_observations(
313
+ query=query,
314
+ background_context=background_context,
315
+ conversation_history=conversation_history,
316
+ iteration=self.iteration,
317
+ )
318
+
319
+ # Track tokens for this iteration
320
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, observations)
321
+ self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
322
+ self.logger.debug(
323
+ "Tokens tracked for thinking agent",
324
+ iteration=self.iteration,
325
+ tokens=estimated_tokens,
326
+ )
327
+
328
+ self.conversation.set_latest_thought(observations)
329
+ return observations
330
+
331
+ async def _evaluate_gaps(self, query: str, background_context: str = "") -> KnowledgeGapOutput:
332
+ """Evaluate knowledge gaps in current research."""
333
+ if self.start_time:
334
+ elapsed_minutes = (time.time() - self.start_time) / 60
335
+ else:
336
+ elapsed_minutes = 0.0
337
+
338
+ # Build input prompt for token estimation
339
+ conversation_history = self.conversation.compile_conversation_history()
340
+ background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
341
+ input_prompt = f"""
342
+ Current Iteration Number: {self.iteration}
343
+ Time Elapsed: {elapsed_minutes:.2f} minutes of maximum {self.max_time_minutes} minutes
344
+
345
+ ORIGINAL QUERY:
346
+ {query}
347
+
348
+ {background}
349
+
350
+ HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
351
+ {conversation_history or "No previous actions, findings or thoughts available."}
352
+ """
353
+
354
+ evaluation = await self.knowledge_gap_agent.evaluate(
355
+ query=query,
356
+ background_context=background_context,
357
+ conversation_history=conversation_history,
358
+ iteration=self.iteration,
359
+ time_elapsed_minutes=elapsed_minutes,
360
+ max_time_minutes=self.max_time_minutes,
361
+ )
362
+
363
+ # Track tokens for this iteration
364
+ evaluation_text = f"research_complete={evaluation.research_complete}, gaps={len(evaluation.outstanding_gaps)}"
365
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
366
+ input_prompt, evaluation_text
367
+ )
368
+ self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
369
+ self.logger.debug(
370
+ "Tokens tracked for knowledge gap agent",
371
+ iteration=self.iteration,
372
+ tokens=estimated_tokens,
373
+ )
374
+
375
+ if not evaluation.research_complete and evaluation.outstanding_gaps:
376
+ self.conversation.set_latest_gap(evaluation.outstanding_gaps[0])
377
+
378
+ return evaluation
379
+
380
+ async def _assess_with_judge(self, query: str) -> JudgeAssessment:
381
+ """Assess evidence sufficiency using JudgeHandler.
382
+
383
+ Args:
384
+ query: The research query
385
+
386
+ Returns:
387
+ JudgeAssessment with sufficiency evaluation
388
+ """
389
+ state = get_workflow_state()
390
+ evidence = state.evidence # Get all collected evidence
391
+
392
+ self.logger.info(
393
+ "Assessing evidence with judge",
394
+ query=query[:100],
395
+ evidence_count=len(evidence),
396
+ )
397
+
398
+ assessment = await self.judge_handler.assess(query, evidence)
399
+
400
+ # Track tokens for judge call
401
+ # Estimate tokens from query + evidence + assessment
402
+ evidence_text = "\n".join([e.content[:500] for e in evidence[:10]]) # Sample
403
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
404
+ query + evidence_text, str(assessment.reasoning)
405
+ )
406
+ self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
407
+
408
+ self.logger.info(
409
+ "Judge assessment complete",
410
+ sufficient=assessment.sufficient,
411
+ confidence=assessment.confidence,
412
+ recommendation=assessment.recommendation,
413
+ )
414
+
415
+ return assessment
416
+
417
+ async def _select_agents(
418
+ self, gap: str, query: str, background_context: str = ""
419
+ ) -> AgentSelectionPlan:
420
+ """Select tools to address knowledge gap."""
421
+ # Build input prompt for token estimation
422
+ conversation_history = self.conversation.compile_conversation_history()
423
+ background = f"BACKGROUND CONTEXT:\n{background_context}" if background_context else ""
424
+ input_prompt = f"""
425
+ ORIGINAL QUERY:
426
+ {query}
427
+
428
+ KNOWLEDGE GAP TO ADDRESS:
429
+ {gap}
430
+
431
+ {background}
432
+
433
+ HISTORY OF ACTIONS, FINDINGS AND THOUGHTS:
434
+ {conversation_history or "No previous actions, findings or thoughts available."}
435
+ """
436
+
437
+ selection_plan = await self.tool_selector_agent.select_tools(
438
+ gap=gap,
439
+ query=query,
440
+ background_context=background_context,
441
+ conversation_history=conversation_history,
442
+ )
443
+
444
+ # Track tokens for this iteration
445
+ selection_text = f"tasks={len(selection_plan.tasks)}, agents={[task.agent for task in selection_plan.tasks]}"
446
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
447
+ input_prompt, selection_text
448
+ )
449
+ self.budget_tracker.add_iteration_tokens(self.loop_id, self.iteration, estimated_tokens)
450
+ self.logger.debug(
451
+ "Tokens tracked for tool selector agent",
452
+ iteration=self.iteration,
453
+ tokens=estimated_tokens,
454
+ )
455
+
456
+ # Store tool calls in conversation
457
+ tool_calls = [
458
+ f"[Agent] {task.agent} [Query] {task.query} [Entity] {task.entity_website or 'null'}"
459
+ for task in selection_plan.tasks
460
+ ]
461
+ self.conversation.set_latest_tool_calls(tool_calls)
462
+
463
+ return selection_plan
464
+
465
+ def _get_rag_service(self) -> LlamaIndexRAGService | None:
466
+ """
467
+ Get or create RAG service instance.
468
+
469
+ Returns:
470
+ RAG service instance, or None if unavailable
471
+ """
472
+ if self._rag_service is None:
473
+ try:
474
+ self._rag_service = get_rag_service()
475
+ self.logger.info("RAG service initialized for research flow")
476
+ except (ConfigurationError, ImportError) as e:
477
+ self.logger.warning(
478
+ "RAG service unavailable", error=str(e), hint="OPENAI_API_KEY required"
479
+ )
480
+ return None
481
+ return self._rag_service
482
+
483
+ async def _execute_tools(self, tasks: list[AgentTask]) -> dict[str, ToolAgentOutput]:
484
+ """Execute selected tools concurrently."""
485
+ try:
486
+ results = await execute_tool_tasks(tasks)
487
+ except Exception as e:
488
+ # Handle tool execution errors gracefully
489
+ self.logger.error(
490
+ "Tool execution failed",
491
+ error=str(e),
492
+ task_count=len(tasks),
493
+ exc_info=True,
494
+ )
495
+ # Return empty results so the research flow can continue;
496
+ # a report can still be generated from earlier iterations.
497
+ results = {}
498
+
499
+ # Store findings in conversation (only if we have results)
500
+ evidence_list: list[Evidence] = []
501
+ if results:
502
+ findings = [result.output for result in results.values()]
503
+ self.conversation.set_latest_findings(findings)
504
+
505
+ # Convert tool outputs to Evidence objects and store in workflow state
506
+ evidence_list = self._convert_tool_outputs_to_evidence(results)
507
+
508
+ if evidence_list:
509
+ state = get_workflow_state()
510
+ added_count = state.add_evidence(evidence_list)
511
+ self.logger.info(
512
+ "Evidence added to workflow state",
513
+ count=added_count,
514
+ total_evidence=len(state.evidence),
515
+ )
516
+
517
+ # Ingest evidence into RAG if available (Phase 6 requirement)
518
+ rag_service = self._get_rag_service()
519
+ if rag_service is not None:
520
+ try:
521
+ # ingest_evidence is synchronous, run in executor to avoid blocking
522
+ loop = asyncio.get_running_loop()
523
+ await loop.run_in_executor(None, rag_service.ingest_evidence, evidence_list)
524
+ self.logger.info(
525
+ "Evidence ingested into RAG",
526
+ count=len(evidence_list),
527
+ )
528
+ except Exception as e:
529
+ # Don't fail the research loop if RAG ingestion fails
530
+ self.logger.warning(
531
+ "Failed to ingest evidence into RAG",
532
+ error=str(e),
533
+ count=len(evidence_list),
534
+ )
535
+
536
+ return results
537
+
538
+ def _convert_tool_outputs_to_evidence(
539
+ self, tool_results: dict[str, ToolAgentOutput]
540
+ ) -> list[Evidence]:
541
+ """Convert ToolAgentOutput to Evidence objects.
542
+
543
+ Args:
544
+ tool_results: Dictionary of tool execution results
545
+
546
+ Returns:
547
+ List of Evidence objects
548
+ """
549
+ evidence_list = []
550
+ for key, result in tool_results.items():
551
+ # Extract URLs from sources
552
+ if result.sources:
553
+ # Create one Evidence object per source URL
554
+ for url in result.sources:
555
+ # Determine source type from URL or tool name
556
+ # Default to "web" for unknown web sources
557
+ source_type: SourceName = "web"
558
+ if "pubmed" in url.lower() or "ncbi" in url.lower():
559
+ source_type = "pubmed"
560
+ elif "clinicaltrials" in url.lower():
561
+ source_type = "clinicaltrials"
562
+ elif "europepmc" in url.lower():
563
+ source_type = "europepmc"
564
+ elif "biorxiv" in url.lower():
565
+ source_type = "biorxiv"
566
+ elif "arxiv" in url.lower() or "preprint" in url.lower():
567
+ source_type = "preprint"
568
+ # Note: "web" is now a valid SourceName for general web sources
569
+
570
+ citation = Citation(
571
+ title=f"Tool Result: {key}",
572
+ url=url,
573
+ source=source_type,
574
+ date="n.d.",
575
+ authors=[],
576
+ )
577
+ # Truncate content to reasonable length for judge (1500 chars)
578
+ content = result.output[:1500]
579
+ if len(result.output) > 1500:
580
+ content += "... [truncated]"
581
+
582
+ evidence = Evidence(
583
+ content=content,
584
+ citation=citation,
585
+ relevance=0.5, # Default relevance
586
+ )
587
+ evidence_list.append(evidence)
588
+ else:
589
+ # No URLs, create a single Evidence object with tool output
590
+ # Use a placeholder URL based on the tool name
591
+ # Determine source type from tool name
592
+ tool_source_type: SourceName = "web" # Default for unknown sources
593
+ if "RAG" in key:
594
+ tool_source_type = "rag"
595
+ elif "WebSearch" in key or "SiteCrawler" in key:
596
+ tool_source_type = "web"
597
+ # "web" is now a valid SourceName for general web sources
598
+
599
+ citation = Citation(
600
+ title=f"Tool Result: {key}",
601
+ url=f"tool://{key}",
602
+ source=tool_source_type,
603
+ date="n.d.",
604
+ authors=[],
605
+ )
606
+ content = result.output[:1500]
607
+ if len(result.output) > 1500:
608
+ content += "... [truncated]"
609
+
610
+ evidence = Evidence(
611
+ content=content,
612
+ citation=citation,
613
+ relevance=0.5,
614
+ )
615
+ evidence_list.append(evidence)
616
+
617
+ return evidence_list
618
+
619
+ async def _create_final_report(
620
+ self, query: str, length: str = "", instructions: str = ""
621
+ ) -> str:
622
+ """Create final report from all findings."""
623
+ all_findings = "\n\n".join(self.conversation.get_all_findings())
624
+ if not all_findings:
625
+ all_findings = "No findings available yet."
626
+
627
+ # Build input prompt for token estimation
628
+ length_str = f"* The full response should be approximately {length}.\n" if length else ""
629
+ instructions_str = f"* {instructions}" if instructions else ""
630
+ guidelines_str = (
631
+ ("\n\nGUIDELINES:\n" + length_str + instructions_str).strip("\n")
632
+ if length or instructions
633
+ else ""
634
+ )
635
+ input_prompt = f"""
636
+ Provide a response based on the query and findings below with as much detail as possible. {guidelines_str}
637
+
638
+ QUERY: {query}
639
+
640
+ FINDINGS:
641
+ {all_findings}
642
+ """
643
+
644
+ report = await self.writer_agent.write_report(
645
+ query=query,
646
+ findings=all_findings,
647
+ output_length=length,
648
+ output_instructions=instructions,
649
+ )
650
+
651
+ # Track tokens for final report (not per iteration, just total)
652
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, report)
653
+ self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
654
+ self.logger.debug(
655
+ "Tokens tracked for writer agent (final report)",
656
+ tokens=estimated_tokens,
657
+ )
658
+
659
+ # Note: Citation validation for markdown reports would require Evidence objects
660
+ # Currently, findings are strings, not Evidence objects. For full validation,
661
+ # consider using ResearchReport format or passing Evidence objects separately.
662
+ # See src/utils/citation_validator.py for markdown citation validation utilities.
663
+
664
+ return report
665
+
666
+
667
+ class DeepResearchFlow:
668
+ """
669
+ Deep research flow that runs parallel iterative loops per section.
670
+
671
+ Pattern: Plan → Parallel Iterative Loops (one per section) → Synthesis
672
+ """
673
+
674
+ def __init__(
675
+ self,
676
+ max_iterations: int = 5,
677
+ max_time_minutes: int = 10,
678
+ verbose: bool = True,
679
+ use_long_writer: bool = True,
680
+ use_graph: bool = False,
681
+ ) -> None:
682
+ """
683
+ Initialize deep research flow.
684
+
685
+ Args:
686
+ max_iterations: Maximum iterations per section
687
+ max_time_minutes: Maximum time per section
688
+ verbose: Whether to log progress
689
+ use_long_writer: Whether to use long writer (True) or proofreader (False)
690
+ use_graph: Whether to use graph-based execution (True) or agent chains (False)
691
+ """
692
+ self.max_iterations = max_iterations
693
+ self.max_time_minutes = max_time_minutes
694
+ self.verbose = verbose
695
+ self.use_long_writer = use_long_writer
696
+ self.use_graph = use_graph
697
+ self.logger = logger
698
+
699
+ # Initialize agents (only needed for agent chain execution)
700
+ if not use_graph:
701
+ self.planner_agent = create_planner_agent()
702
+ self.long_writer_agent = create_long_writer_agent()
703
+ self.proofreader_agent = create_proofreader_agent()
704
+ # Initialize judge handler for section loop completion
705
+ self.judge_handler = create_judge_handler()
706
+ # Initialize budget tracker for token tracking
707
+ self.budget_tracker = BudgetTracker()
708
+ self.loop_id = "deep_research_flow"
709
+ self.budget_tracker.create_budget(
710
+ loop_id=self.loop_id,
711
+ tokens_limit=200000, # Higher limit for deep research
712
+ time_limit_seconds=max_time_minutes
713
+ * 60
714
+ * 2, # Allow more time for parallel sections
715
+ iterations_limit=max_iterations * 10, # Allow for multiple sections
716
+ )
717
+ self.budget_tracker.start_timer(self.loop_id)
718
+
719
+ # Graph orchestrator (lazy initialization)
720
+ self._graph_orchestrator: Any = None
721
+
722
+ async def run(self, query: str) -> str:
723
+ """
724
+ Run the deep research flow.
725
+
726
+ Args:
727
+ query: The research query
728
+
729
+ Returns:
730
+ Final report string
731
+ """
732
+ if self.use_graph:
733
+ return await self._run_with_graph(query)
734
+ else:
735
+ return await self._run_with_chains(query)
736
+
737
+ async def _run_with_chains(self, query: str) -> str:
738
+ """
739
+ Run the deep research flow using agent chains.
740
+
741
+ Args:
742
+ query: The research query
743
+
744
+ Returns:
745
+ Final report string
746
+ """
747
+ self.logger.info("Starting deep research (agent chains)", query=query[:100])
748
+
749
+ # Initialize workflow state for deep research
750
+ try:
751
+ from src.services.embeddings import get_embedding_service
752
+
753
+ embedding_service = get_embedding_service()
754
+ except Exception:  # Exception already subsumes ImportError
755
+ # If embedding service is unavailable, initialize without it
756
+ embedding_service = None
757
+ self.logger.debug("Embedding service unavailable, initializing state without it")
758
+
759
+ init_workflow_state(embedding_service=embedding_service)
760
+ self.logger.debug("Workflow state initialized for deep research")
761
+
762
+ # 1. Build report plan
763
+ report_plan = await self._build_report_plan(query)
764
+ self.logger.info(
765
+ "Report plan created",
766
+ sections=len(report_plan.report_outline),
767
+ title=report_plan.report_title,
768
+ )
769
+
770
+ # 2. Run parallel research loops with state synchronization
771
+ section_drafts = await self._run_research_loops(report_plan)
772
+
773
+ # Verify state synchronization - log evidence count
774
+ state = get_workflow_state()
775
+ self.logger.info(
776
+ "State synchronization complete",
777
+ total_evidence=len(state.evidence),
778
+ sections_completed=len(section_drafts),
779
+ )
780
+
781
+ # 3. Create final report
782
+ final_report = await self._create_final_report(query, report_plan, section_drafts)
783
+
784
+ self.logger.info(
785
+ "Deep research completed",
786
+ sections=len(section_drafts),
787
+ final_report_length=len(final_report),
788
+ )
789
+
790
+ return final_report
791
+
792
+ async def _run_with_graph(self, query: str) -> str:
793
+ """
794
+ Run the deep research flow using graph execution.
795
+
796
+ Args:
797
+ query: The research query
798
+
799
+ Returns:
800
+ Final report string
801
+ """
802
+ self.logger.info("Starting deep research (graph execution)", query=query[:100])
803
+
804
+ # Create graph orchestrator (lazy initialization)
805
+ if self._graph_orchestrator is None:
806
+ self._graph_orchestrator = create_graph_orchestrator(
807
+ mode="deep",
808
+ max_iterations=self.max_iterations,
809
+ max_time_minutes=self.max_time_minutes,
810
+ use_graph=True,
811
+ )
812
+
813
+ # Run orchestrator and collect events
814
+ final_report = ""
815
+ async for event in self._graph_orchestrator.run(query):
816
+ if event.type == "complete":
817
+ final_report = event.message
818
+ break
819
+ elif event.type == "error":
820
+ self.logger.error("Graph execution error", error=event.message)
821
+ raise RuntimeError(f"Graph execution failed: {event.message}")
822
+
823
+ if not final_report:
824
+ self.logger.warning("No complete event received from graph orchestrator")
825
+ final_report = "Research completed but no report was generated."
826
+
827
+ self.logger.info("Deep research completed (graph execution)")
828
+
829
+ return final_report
830
+
831
+ async def _build_report_plan(self, query: str) -> ReportPlan:
832
+ """Build the initial report plan."""
833
+ self.logger.info("Building report plan")
834
+
835
+ # Build input prompt for token estimation
836
+ input_prompt = f"QUERY: {query}"
837
+
838
+ report_plan = await self.planner_agent.run(query)
839
+
840
+ # Track tokens for planner agent
841
+ if not self.use_graph and hasattr(self, "budget_tracker"):
842
+ plan_text = (
843
+ f"title={report_plan.report_title}, sections={len(report_plan.report_outline)}"
844
+ )
845
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(input_prompt, plan_text)
846
+ self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
847
+ self.logger.debug(
848
+ "Tokens tracked for planner agent",
849
+ tokens=estimated_tokens,
850
+ )
851
+
852
+ self.logger.info(
853
+ "Report plan created",
854
+ sections=len(report_plan.report_outline),
855
+ has_background=bool(report_plan.background_context),
856
+ )
857
+
858
+ return report_plan
859
+
860
+ async def _run_research_loops(self, report_plan: ReportPlan) -> list[str]:
861
+ """Run parallel iterative research loops for each section."""
862
+ self.logger.info("Running research loops", sections=len(report_plan.report_outline))
863
+
864
+ # Create workflow manager for parallel execution
865
+ workflow_manager = WorkflowManager()
866
+
867
+ # Create loop configurations
868
+ loop_configs = [
869
+ {
870
+ "loop_id": f"section_{i}",
871
+ "query": section.key_question,
872
+ "section_title": section.title,
873
+ "background_context": report_plan.background_context,
874
+ }
875
+ for i, section in enumerate(report_plan.report_outline)
876
+ ]
877
+
878
+ async def run_research_for_section(config: dict[str, Any]) -> str:
879
+ """Run iterative research for a single section."""
880
+ loop_id = config.get("loop_id", "unknown")
881
+ query = config.get("query", "")
882
+ background_context = config.get("background_context", "")
883
+
884
+ try:
885
+ # Update loop status
886
+ await workflow_manager.update_loop_status(loop_id, "running")
887
+
888
+ # Create iterative research flow
889
+ flow = IterativeResearchFlow(
890
+ max_iterations=self.max_iterations,
891
+ max_time_minutes=self.max_time_minutes,
892
+ verbose=self.verbose,
893
+ use_graph=self.use_graph,
894
+ judge_handler=self.judge_handler if not self.use_graph else None,
895
+ )
896
+
897
+ # Run research
898
+ result = await flow.run(
899
+ query=query,
900
+ background_context=background_context,
901
+ )
902
+
903
+ # Sync evidence from flow to loop
904
+ state = get_workflow_state()
905
+ if state.evidence:
906
+ await workflow_manager.add_loop_evidence(loop_id, state.evidence)
907
+
908
+ # Update loop status
909
+ await workflow_manager.update_loop_status(loop_id, "completed")
910
+
911
+ return result
912
+
913
+ except Exception as e:
914
+ error_msg = str(e)
915
+ await workflow_manager.update_loop_status(loop_id, "failed", error=error_msg)
916
+ self.logger.error(
917
+ "Section research failed",
918
+ loop_id=loop_id,
919
+ error=error_msg,
920
+ )
921
+ raise
922
+
923
+ # Run all sections in parallel using workflow manager
924
+ section_drafts = await workflow_manager.run_loops_parallel(
925
+ loop_configs=loop_configs,
926
+ loop_func=run_research_for_section,
927
+ judge_handler=self.judge_handler if not self.use_graph else None,
928
+ budget_tracker=self.budget_tracker if not self.use_graph else None,
929
+ )
930
+
931
+ # Sync evidence from all loops to global state
932
+ for config in loop_configs:
933
+ loop_id = config.get("loop_id")
934
+ if loop_id:
935
+ await workflow_manager.sync_loop_evidence_to_state(loop_id)
936
+
937
+ # Filter out None results (failed loops)
938
+ section_drafts = [draft for draft in section_drafts if draft is not None]
939
+
940
+ self.logger.info(
941
+ "Research loops completed",
942
+ drafts=len(section_drafts),
943
+ total_sections=len(report_plan.report_outline),
944
+ )
945
+
946
+ return section_drafts
947
+
948
+ async def _create_final_report(
949
+ self, query: str, report_plan: ReportPlan, section_drafts: list[str]
950
+ ) -> str:
951
+ """Create final report from section drafts."""
952
+ self.logger.info("Creating final report")
953
+
954
+ # Create ReportDraft from section drafts
955
+ report_draft = ReportDraft(
956
+ sections=[
957
+ ReportDraftSection(
958
+ section_title=section.title,
959
+ section_content=draft,
960
+ )
961
+ for section, draft in zip(report_plan.report_outline, section_drafts, strict=False)
962
+ ]
963
+ )
964
+
965
+ # Build input prompt for token estimation
966
+ draft_text = "\n".join(
967
+ [s.section_content[:500] for s in report_draft.sections[:5]]
968
+ ) # Sample
969
+ input_prompt = f"QUERY: {query}\nTITLE: {report_plan.report_title}\nDRAFT: {draft_text}"
970
+
971
+ if self.use_long_writer:
972
+ # Use long writer agent
973
+ final_report = await self.long_writer_agent.write_report(
974
+ original_query=query,
975
+ report_title=report_plan.report_title,
976
+ report_draft=report_draft,
977
+ )
978
+ else:
979
+ # Use proofreader agent
980
+ final_report = await self.proofreader_agent.proofread(
981
+ query=query,
982
+ report_draft=report_draft,
983
+ )
984
+
985
+ # Track tokens for final report synthesis
986
+ if not self.use_graph and hasattr(self, "budget_tracker"):
987
+ estimated_tokens = self.budget_tracker.estimate_llm_call_tokens(
988
+ input_prompt, final_report
989
+ )
990
+ self.budget_tracker.add_tokens(self.loop_id, estimated_tokens)
991
+ self.logger.debug(
992
+ "Tokens tracked for final report synthesis",
993
+ tokens=estimated_tokens,
994
+ agent="long_writer" if self.use_long_writer else "proofreader",
995
+ )
996
+
997
+ self.logger.info("Final report created", length=len(final_report))
998
+
999
+ return final_report
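A minimal usage sketch for the two flows added above. The import path and the queries are illustrative assumptions, not taken from the commit; adjust the import to wherever `IterativeResearchFlow` and `DeepResearchFlow` are exported.

```python
import asyncio

# Assumed import path for this commit's research flows (illustrative).
from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow


async def main() -> None:
    # Single search-and-judge loop over one query (agent chains by default).
    iterative = IterativeResearchFlow(max_iterations=3, max_time_minutes=5)
    report = await iterative.run("What existing drugs might help treat long COVID fatigue?")

    # Plan -> parallel per-section loops -> synthesis.
    deep = DeepResearchFlow(max_iterations=3, max_time_minutes=5, use_long_writer=True)
    final_report = await deep.run("Drug repurposing candidates for long COVID fatigue")

    print(report[:300])
    print(final_report[:300])


asyncio.run(main())
```

Setting `use_graph=True` on either flow delegates to the graph orchestrator instead of the agent chains shown in `_run_with_chains`.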
src/orchestrator_factory.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  from typing import Any, Literal
4
 
5
- from src.orchestrator import JudgeHandlerProtocol, Orchestrator, SearchHandlerProtocol
6
  from src.utils.models import OrchestratorConfig
7
 
8
 
 
2
 
3
  from typing import Any, Literal
4
 
5
+ from src.legacy_orchestrator import JudgeHandlerProtocol, Orchestrator, SearchHandlerProtocol
6
  from src.utils.models import OrchestratorConfig
7
 
8
 
src/tools/__init__.py CHANGED
@@ -2,7 +2,14 @@
2
 
3
  from src.tools.base import SearchTool
4
  from src.tools.pubmed import PubMedTool
 
5
  from src.tools.search_handler import SearchHandler
6
 
7
  # Re-export
8
- __all__ = ["PubMedTool", "SearchHandler", "SearchTool"]
2
 
3
  from src.tools.base import SearchTool
4
  from src.tools.pubmed import PubMedTool
5
+ from src.tools.rag_tool import RAGTool, create_rag_tool
6
  from src.tools.search_handler import SearchHandler
7
 
8
  # Re-export
9
+ __all__ = [
10
+ "PubMedTool",
11
+ "SearchHandler",
12
+ "SearchTool",
13
+ "RAGTool",
14
+ "create_rag_tool",
15
+ ]
src/tools/crawl_adapter.py ADDED
@@ -0,0 +1,58 @@
1
+ """Website crawl tool adapter for Pydantic AI agents.
2
+
3
+ Adapts the folder/tools/crawl_website.py implementation to work with Pydantic AI.
4
+ """
5
+
6
+ import structlog
7
+
8
+ logger = structlog.get_logger()
9
+
10
+
11
+ async def crawl_website(starting_url: str) -> str:
12
+ """
13
+ Crawl a website starting from the given URL and return formatted results.
14
+
15
+ Use this tool to crawl a website for information relevant to the query.
16
+ Provide a starting URL as input.
17
+
18
+ Args:
19
+ starting_url: The starting URL to crawl (e.g., "https://example.com")
20
+
21
+ Returns:
22
+ Formatted string with crawled content including titles, descriptions, and URLs
23
+ """
24
+ try:
25
+ # Lazy import to avoid requiring folder/ dependencies at import time
26
+ from folder.tools.crawl_website import crawl_website as crawl_tool
27
+
28
+ # Call the tool function
29
+ # The tool returns List[ScrapeResult] or str
30
+ results = await crawl_tool(starting_url)
31
+
32
+ if isinstance(results, str):
33
+ # Error message returned
34
+ logger.warning("Crawl returned error", error=results)
35
+ return results
36
+
37
+ if not results:
38
+ return f"No content found when crawling: {starting_url}"
39
+
40
+ # Format results for agent consumption
41
+ formatted = [f"Found {len(results)} pages from {starting_url}:\n"]
42
+ for i, result in enumerate(results[:10], 1): # Limit to 10 pages
43
+ formatted.append(f"{i}. **{result.title or 'Untitled'}**")
44
+ if result.description:
45
+ formatted.append(f" {result.description[:200]}...")
46
+ formatted.append(f" URL: {result.url}")
47
+ if result.text:
48
+ formatted.append(f" Content: {result.text[:500]}...")
49
+ formatted.append("")
50
+
51
+ return "\n".join(formatted)
52
+
53
+ except ImportError as e:
54
+ logger.error("Crawl tool not available", error=str(e))
55
+ return f"Crawl tool not available: {e!s}"
56
+ except Exception as e:
57
+ logger.error("Crawl failed", error=str(e), url=starting_url)
58
+ return f"Error crawling website: {e!s}"
src/tools/rag_tool.py ADDED
@@ -0,0 +1,183 @@
1
+ """RAG tool for semantic search within collected evidence.
2
+
3
+ Implements SearchTool protocol to enable RAG as a search option in the research workflow.
4
+ """
5
+
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import structlog
9
+
10
+ from src.utils.exceptions import ConfigurationError
11
+ from src.utils.models import Citation, Evidence, SourceName
12
+
13
+ if TYPE_CHECKING:
14
+ from src.services.llamaindex_rag import LlamaIndexRAGService
15
+
16
+ logger = structlog.get_logger()
17
+
18
+
19
+ class RAGTool:
20
+ """Search tool that uses LlamaIndex RAG for semantic search within collected evidence.
21
+
22
+ Wraps LlamaIndexRAGService to implement the SearchTool protocol.
23
+ Returns Evidence objects from RAG retrieval results.
24
+ """
25
+
26
+ def __init__(self, rag_service: "LlamaIndexRAGService | None" = None) -> None:
27
+ """
28
+ Initialize RAG tool.
29
+
30
+ Args:
31
+ rag_service: Optional RAG service instance. If None, will be lazy-initialized.
32
+ """
33
+ self._rag_service = rag_service
34
+ self.logger = logger
35
+
36
+ @property
37
+ def name(self) -> str:
38
+ """Return the tool name."""
39
+ return "rag"
40
+
41
+ def _get_rag_service(self) -> "LlamaIndexRAGService":
42
+ """
43
+ Get or create RAG service instance.
44
+
45
+ Returns:
46
+ LlamaIndexRAGService instance
47
+
48
+ Raises:
49
+ ConfigurationError: If RAG service cannot be initialized
50
+ """
51
+ if self._rag_service is None:
52
+ try:
53
+ from src.services.llamaindex_rag import get_rag_service
54
+
55
+ self._rag_service = get_rag_service()
56
+ self.logger.info("RAG service initialized")
57
+ except (ConfigurationError, ImportError) as e:
58
+ self.logger.error("Failed to initialize RAG service", error=str(e))
59
+ raise ConfigurationError("RAG service unavailable. OPENAI_API_KEY required.") from e
60
+
61
+ return self._rag_service
62
+
63
+ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
64
+ """
65
+ Search RAG system and return evidence.
66
+
67
+ Args:
68
+ query: The search query string
69
+ max_results: Maximum number of results to return
70
+
71
+ Returns:
72
+ List of Evidence objects from RAG retrieval
73
+
74
+ Note:
75
+ Returns empty list on error (does not raise exceptions).
76
+ """
77
+ try:
78
+ rag_service = self._get_rag_service()
79
+ except ConfigurationError:
80
+ self.logger.warning("RAG service unavailable, returning empty results")
81
+ return []
82
+
83
+ try:
84
+ # Retrieve documents from RAG
85
+ retrieved_docs = rag_service.retrieve(query, top_k=max_results)
86
+
87
+ if not retrieved_docs:
88
+ self.logger.info("No RAG results found", query=query[:50])
89
+ return []
90
+
91
+ # Convert retrieved documents to Evidence objects
92
+ evidence_list: list[Evidence] = []
93
+ for doc in retrieved_docs:
94
+ try:
95
+ evidence = self._doc_to_evidence(doc)
96
+ evidence_list.append(evidence)
97
+ except Exception as e:
98
+ self.logger.warning(
99
+ "Failed to convert document to evidence",
100
+ error=str(e),
101
+ doc_text=doc.get("text", "")[:50],
102
+ )
103
+ continue
104
+
105
+ self.logger.info(
106
+ "RAG search completed",
107
+ query=query[:50],
108
+ results=len(evidence_list),
109
+ )
110
+ return evidence_list
111
+
112
+ except Exception as e:
113
+ self.logger.error("RAG search failed", error=str(e), query=query[:50])
114
+ # Return empty list on error (graceful degradation)
115
+ return []
116
+
117
+ def _doc_to_evidence(self, doc: dict[str, Any]) -> Evidence:
118
+ """
119
+ Convert RAG document to Evidence object.
120
+
121
+ Args:
122
+ doc: Document dict with keys: text, score, metadata
123
+
124
+ Returns:
125
+ Evidence object
126
+
127
+ Raises:
128
+ ValueError: If document is missing required fields
129
+ """
130
+ text = doc.get("text", "")
131
+ if not text:
132
+ raise ValueError("Document missing text content")
133
+
134
+ metadata = doc.get("metadata", {})
135
+ score = doc.get("score", 0.0)
136
+
137
+ # Extract citation information from metadata
138
+ source: SourceName = "rag" # RAG is the source
139
+ title = metadata.get("title", "Untitled")
140
+ url = metadata.get("url", "")
141
+ date = metadata.get("date", "Unknown")
142
+ authors_str = metadata.get("authors", "")
143
+ authors = [a.strip() for a in authors_str.split(",") if a.strip()] if authors_str else []
144
+
145
+ # Create citation
146
+ citation = Citation(
147
+ source=source,
148
+ title=title[:500], # Enforce max length
149
+ url=url,
150
+ date=date,
151
+ authors=authors,
152
+ )
153
+
154
+ # Create evidence with relevance score (normalize score to 0-1 if needed)
155
+ relevance = min(max(float(score), 0.0), 1.0) if score else 0.0
156
+
157
+ return Evidence(
158
+ content=text,
159
+ citation=citation,
160
+ relevance=relevance,
161
+ )
162
+
163
+
164
+ def create_rag_tool(
165
+ rag_service: "LlamaIndexRAGService | None" = None,
166
+ ) -> RAGTool:
167
+ """
168
+ Factory function to create a RAG tool.
169
+
170
+ Args:
171
+ rag_service: Optional RAG service instance. If None, will be lazy-initialized.
172
+
173
+ Returns:
174
+ Configured RAGTool instance
175
+
176
+ Raises:
177
+ ConfigurationError: If RAG service cannot be initialized and rag_service is None
178
+ """
179
+ try:
180
+ return RAGTool(rag_service=rag_service)
181
+ except Exception as e:
182
+ logger.error("Failed to create RAG tool", error=str(e))
183
+ raise ConfigurationError(f"Failed to create RAG tool: {e}") from e
src/tools/search_handler.py CHANGED
@@ -1,30 +1,74 @@
1
  """Search handler - orchestrates multiple search tools."""
2
 
3
  import asyncio
4
- from typing import cast
5
 
6
  import structlog
7
 
8
  from src.tools.base import SearchTool
9
- from src.utils.exceptions import SearchError
 
10
  from src.utils.models import Evidence, SearchResult, SourceName
11
 
 
 
 
12
  logger = structlog.get_logger()
13
 
14
 
15
  class SearchHandler:
16
  """Orchestrates parallel searches across multiple tools."""
17
 
18
- def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
  """
20
  Initialize the search handler.
21
 
22
  Args:
23
  tools: List of search tools to use
24
  timeout: Timeout for each search in seconds
 
 
25
  """
26
- self.tools = tools
27
  self.timeout = timeout
28
 
29
  async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
30
  """
@@ -66,7 +110,7 @@ class SearchHandler:
66
  sources_searched.append(tool_name)
67
  logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
68
 
69
- return SearchResult(
70
  query=query,
71
  evidence=all_evidence,
72
  sources_searched=sources_searched,
@@ -74,6 +118,24 @@ class SearchHandler:
74
  errors=errors,
75
  )
76
 
77
  async def _search_with_timeout(
78
  self,
79
  tool: SearchTool,
 
1
  """Search handler - orchestrates multiple search tools."""
2
 
3
  import asyncio
4
+ from typing import TYPE_CHECKING, cast
5
 
6
  import structlog
7
 
8
  from src.tools.base import SearchTool
9
+ from src.tools.rag_tool import create_rag_tool
10
+ from src.utils.exceptions import ConfigurationError, SearchError
11
  from src.utils.models import Evidence, SearchResult, SourceName
12
 
13
+ if TYPE_CHECKING:
14
+ from src.services.llamaindex_rag import LlamaIndexRAGService
15
+
16
  logger = structlog.get_logger()
17
 
18
 
19
  class SearchHandler:
20
  """Orchestrates parallel searches across multiple tools."""
21
 
22
+ def __init__(
23
+ self,
24
+ tools: list[SearchTool],
25
+ timeout: float = 30.0,
26
+ include_rag: bool = False,
27
+ auto_ingest_to_rag: bool = True,
28
+ ) -> None:
29
  """
30
  Initialize the search handler.
31
 
32
  Args:
33
  tools: List of search tools to use
34
  timeout: Timeout for each search in seconds
35
+ include_rag: Whether to include RAG tool in searches
36
+ auto_ingest_to_rag: Whether to automatically ingest results into RAG
37
  """
38
+ self.tools = list(tools) # Make a copy
39
  self.timeout = timeout
40
+ self.auto_ingest_to_rag = auto_ingest_to_rag
41
+ self._rag_service: "LlamaIndexRAGService | None" = None
42
+
43
+ if include_rag:
44
+ self.add_rag_tool()
45
+
46
+ def add_rag_tool(self) -> None:
47
+ """Add RAG tool to the tools list if available."""
48
+ try:
49
+ rag_tool = create_rag_tool()
50
+ self.tools.append(rag_tool)
51
+ logger.info("RAG tool added to search handler")
52
+ except ConfigurationError:
53
+ logger.warning(
54
+ "RAG tool unavailable, not adding to search handler",
55
+ hint="OPENAI_API_KEY required",
56
+ )
57
+ except Exception as e:
58
+ logger.error("Failed to add RAG tool", error=str(e))
59
+
60
+ def _get_rag_service(self) -> "LlamaIndexRAGService | None":
61
+ """Get or create RAG service for ingestion."""
62
+ if self._rag_service is None and self.auto_ingest_to_rag:
63
+ try:
64
+ from src.services.llamaindex_rag import get_rag_service
65
+
66
+ self._rag_service = get_rag_service()
67
+ logger.info("RAG service initialized for ingestion")
68
+ except (ConfigurationError, ImportError):
69
+ logger.warning("RAG service unavailable for ingestion")
70
+ return None
71
+ return self._rag_service
72
 
73
  async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
74
  """
 
110
  sources_searched.append(tool_name)
111
  logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
112
 
113
+ search_result = SearchResult(
114
  query=query,
115
  evidence=all_evidence,
116
  sources_searched=sources_searched,
 
118
  errors=errors,
119
  )
120
 
121
+ # Ingest evidence into RAG if enabled and available
122
+ if self.auto_ingest_to_rag and all_evidence:
123
+ rag_service = self._get_rag_service()
124
+ if rag_service:
125
+ try:
126
+ # Filter out RAG-sourced evidence (avoid circular ingestion)
127
+ evidence_to_ingest = [e for e in all_evidence if e.citation.source != "rag"]
128
+ if evidence_to_ingest:
129
+ rag_service.ingest_evidence(evidence_to_ingest)
130
+ logger.info(
131
+ "Ingested evidence into RAG",
132
+ count=len(evidence_to_ingest),
133
+ )
134
+ except Exception as e:
135
+ logger.warning("Failed to ingest evidence into RAG", error=str(e))
136
+
137
+ return search_result
138
+
139
  async def _search_with_timeout(
140
  self,
141
  tool: SearchTool,
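A sketch of the new constructor options in use. `PubMedTool` stands in for any `SearchTool` and is assumed to construct without arguments; the query is illustrative.

```python
import asyncio

from src.tools.pubmed import PubMedTool
from src.tools.search_handler import SearchHandler

# include_rag=True appends the RAG tool (if OPENAI_API_KEY is set);
# auto_ingest_to_rag=True feeds non-RAG evidence back into the index.
handler = SearchHandler(
    tools=[PubMedTool()],
    timeout=30.0,
    include_rag=True,
    auto_ingest_to_rag=True,
)

result = asyncio.run(handler.execute("long COVID fatigue treatment", max_results_per_tool=10))
print(result.sources_searched, len(result.evidence))
```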
src/tools/tool_executor.py ADDED
@@ -0,0 +1,193 @@
1
+ """Tool executor for running AgentTask objects.
2
+
3
+ Executes tool tasks selected by the tool selector agent and returns ToolAgentOutput.
4
+ """
5
+
6
+ import structlog
7
+
8
+ from src.tools.crawl_adapter import crawl_website
9
+ from src.tools.rag_tool import RAGTool, create_rag_tool
10
+ from src.tools.web_search_adapter import web_search
11
+ from src.utils.exceptions import ConfigurationError
12
+ from src.utils.models import AgentTask, Evidence, ToolAgentOutput
13
+
14
+ logger = structlog.get_logger()
15
+
16
+ # Module-level RAG tool instance (lazy initialization)
17
+ _rag_tool: RAGTool | None = None
18
+
19
+
20
+ def _get_rag_tool() -> RAGTool | None:
21
+ """
22
+ Get or create RAG tool instance.
23
+
24
+ Returns:
25
+ RAGTool instance, or None if unavailable
26
+ """
27
+ global _rag_tool
28
+ if _rag_tool is None:
29
+ try:
30
+ _rag_tool = create_rag_tool()
31
+ logger.info("RAG tool initialized")
32
+ except ConfigurationError:
33
+ logger.warning("RAG tool unavailable (OPENAI_API_KEY required)")
34
+ return None
35
+ except Exception as e:
36
+ logger.error("Failed to initialize RAG tool", error=str(e))
37
+ return None
38
+ return _rag_tool
39
+
40
+
41
+ def _evidence_to_text(evidence_list: list[Evidence]) -> str:
42
+ """
43
+ Convert Evidence objects to formatted text.
44
+
45
+ Args:
46
+ evidence_list: List of Evidence objects
47
+
48
+ Returns:
49
+ Formatted text string with citations and content
50
+ """
51
+ if not evidence_list:
52
+ return "No evidence found."
53
+
54
+ formatted_parts = []
55
+ for i, evidence in enumerate(evidence_list, 1):
56
+ citation = evidence.citation
57
+ citation_str = f"{citation.formatted}"
58
+ if citation.url:
59
+ citation_str += f" [{citation.url}]"
60
+
61
+ formatted_parts.append(f"[{i}] {citation_str}\n\n{evidence.content}\n\n---\n")
62
+
63
+ return "\n".join(formatted_parts)
64
+
65
+
66
+ async def execute_agent_task(task: AgentTask) -> ToolAgentOutput:
67
+ """
68
+ Execute a single agent task and return ToolAgentOutput.
69
+
70
+ Args:
71
+ task: AgentTask specifying which tool to use and what query to run
72
+
73
+ Returns:
74
+ ToolAgentOutput with results and source URLs
75
+ """
76
+ logger.info(
77
+ "Executing agent task",
78
+ agent=task.agent,
79
+ query=task.query[:100] if task.query else "",
80
+ gap=task.gap[:100] if task.gap else "",
81
+ )
82
+
83
+ try:
84
+ if task.agent == "WebSearchAgent":
85
+ # Use web search adapter
86
+ result_text = await web_search(task.query)
87
+ # Extract URLs from result (simple heuristic - look for http/https)
88
+ import re
89
+
90
+ urls = re.findall(r"https?://[^\s\)]+", result_text)
91
+ sources = list(set(urls)) # Deduplicate
92
+
93
+ return ToolAgentOutput(output=result_text, sources=sources)
94
+
95
+ elif task.agent == "SiteCrawlerAgent":
96
+ # Use crawl adapter
97
+ if task.entity_website:
98
+ starting_url = task.entity_website
99
+ elif task.query.startswith(("http://", "https://")):
100
+ starting_url = task.query
101
+ else:
102
+ # Try to construct URL from query
103
+ starting_url = f"https://{task.query}"
104
+
105
+ result_text = await crawl_website(starting_url)
106
+ # Extract URLs from result
107
+ import re
108
+
109
+ urls = re.findall(r"https?://[^\s\)]+", result_text)
110
+ sources = list(set(urls)) # Deduplicate
111
+
112
+ return ToolAgentOutput(output=result_text, sources=sources)
113
+
114
+ elif task.agent == "RAGAgent":
115
+ # Use RAG tool for semantic search
116
+ rag_tool = _get_rag_tool()
117
+ if rag_tool is None:
118
+ return ToolAgentOutput(
119
+ output="RAG service unavailable. OPENAI_API_KEY required.",
120
+ sources=[],
121
+ )
122
+
123
+ # Search RAG and get Evidence objects
124
+ evidence_list = await rag_tool.search(task.query, max_results=10)
125
+
126
+ if not evidence_list:
127
+ return ToolAgentOutput(
128
+ output="No relevant evidence found in collected research.",
129
+ sources=[],
130
+ )
131
+
132
+ # Convert Evidence to formatted text
133
+ result_text = _evidence_to_text(evidence_list)
134
+
135
+ # Extract URLs from evidence citations
136
+ sources = [evidence.citation.url for evidence in evidence_list if evidence.citation.url]
137
+
138
+ return ToolAgentOutput(output=result_text, sources=sources)
139
+
140
+ else:
141
+ logger.warning("Unknown agent type", agent=task.agent)
142
+ return ToolAgentOutput(
143
+ output=f"Unknown agent type: {task.agent}. Available: WebSearchAgent, SiteCrawlerAgent, RAGAgent",
144
+ sources=[],
145
+ )
146
+
147
+ except Exception as e:
148
+ logger.error("Tool execution failed", error=str(e), agent=task.agent)
149
+ return ToolAgentOutput(
150
+ output=f"Error executing {task.agent} for gap '{task.gap}': {e!s}",
151
+ sources=[],
152
+ )
153
+
154
+
155
+ async def execute_tool_tasks(
156
+ tasks: list[AgentTask],
157
+ ) -> dict[str, ToolAgentOutput]:
158
+ """
159
+ Execute multiple agent tasks concurrently.
160
+
161
+ Args:
162
+ tasks: List of AgentTask objects to execute
163
+
164
+ Returns:
165
+ Dictionary mapping task keys to ToolAgentOutput results
166
+ """
167
+ import asyncio
168
+
169
+ logger.info("Executing tool tasks", count=len(tasks))
170
+
171
+ # Create async tasks
172
+ async_tasks = [execute_agent_task(task) for task in tasks]
173
+
174
+ # Run concurrently
175
+ results_list = await asyncio.gather(*async_tasks, return_exceptions=True)
176
+
177
+ # Build results dictionary
178
+ results: dict[str, ToolAgentOutput] = {}
179
+ for i, (task, result) in enumerate(zip(tasks, results_list, strict=False)):
180
+ if isinstance(result, Exception):
181
+ logger.error("Task execution failed", error=str(result), task_index=i)
182
+ results[f"{task.agent}_{i}"] = ToolAgentOutput(output=f"Error: {result!s}", sources=[])
183
+ else:
184
+ # Type narrowing: result is ToolAgentOutput after Exception check
185
+ assert isinstance(
186
+ result, ToolAgentOutput
187
+ ), "Expected ToolAgentOutput after Exception check"
188
+ key = f"{task.agent}_{task.gap or i}" if task.gap else f"{task.agent}_{i}"
189
+ results[key] = result
190
+
191
+ logger.info("Tool tasks completed", completed=len(results))
192
+
193
+ return results
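A sketch of executing a small task batch. The `AgentTask` field values are illustrative; the field names follow their use in `execute_agent_task` above, and any other fields are assumed to have defaults.

```python
import asyncio

from src.tools.tool_executor import execute_tool_tasks
from src.utils.models import AgentTask

tasks = [
    AgentTask(agent="WebSearchAgent", query="long COVID fatigue treatments", gap="current therapies"),
    AgentTask(agent="RAGAgent", query="mitochondrial dysfunction evidence", gap="mechanism support"),
]

# Tasks run concurrently; per-task failures become error ToolAgentOutput
# entries instead of aborting the whole batch.
results = asyncio.run(execute_tool_tasks(tasks))
for key, output in results.items():
    print(key, "->", len(output.sources), "sources")
```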
src/tools/web_search_adapter.py ADDED
@@ -0,0 +1,63 @@
1
+ """Web search tool adapter for Pydantic AI agents.
2
+
3
+ Adapts the folder/tools/web_search.py implementation to work with Pydantic AI.
4
+ """
5
+
6
+ import structlog
7
+
8
+ logger = structlog.get_logger()
9
+
10
+
11
+ async def web_search(query: str) -> str:
12
+ """
13
+ Perform a web search for a given query and return formatted results.
14
+
15
+ Use this tool to search the web for information relevant to the query.
16
+ Provide a query with 3-6 words as input.
17
+
18
+ Args:
19
+ query: The search query (3-6 words recommended)
20
+
21
+ Returns:
22
+ Formatted string with search results including titles, descriptions, and URLs
23
+ """
24
+ try:
25
+ # Lazy import to avoid requiring folder/ dependencies at import time
26
+ # This will use the existing web_search tool from folder/tools
27
+ from folder.llm_config import create_default_config
28
+ from folder.tools.web_search import create_web_search_tool
29
+
30
+ config = create_default_config()
31
+ web_search_tool = create_web_search_tool(config)
32
+
33
+ # Call the tool function
34
+ # The tool returns List[ScrapeResult] or str
35
+ results = await web_search_tool(query)
36
+
37
+ if isinstance(results, str):
38
+ # Error message returned
39
+ logger.warning("Web search returned error", error=results)
40
+ return results
41
+
42
+ if not results:
43
+ return f"No web search results found for: {query}"
44
+
45
+ # Format results for agent consumption
46
+ formatted = [f"Found {len(results)} web search results:\n"]
47
+ for i, result in enumerate(results[:5], 1): # Limit to 5 results
48
+ formatted.append(f"{i}. **{result.title}**")
49
+ if result.description:
50
+ formatted.append(f" {result.description[:200]}...")
51
+ formatted.append(f" URL: {result.url}")
52
+ if result.text:
53
+ formatted.append(f" Content: {result.text[:300]}...")
54
+ formatted.append("")
55
+
56
+ return "\n".join(formatted)
57
+
58
+ except ImportError as e:
59
+ logger.error("Web search tool not available", error=str(e))
60
+ return f"Web search tool not available: {e!s}"
61
+ except Exception as e:
62
+ logger.error("Web search failed", error=str(e), query=query)
63
+ return f"Error performing web search: {e!s}"
src/utils/citation_validator.py CHANGED
@@ -85,3 +85,94 @@ def build_reference_from_evidence(evidence: "Evidence") -> dict[str, str]:
85
  "date": evidence.citation.date or "n.d.",
86
  "url": evidence.citation.url,
87
  }
 
85
  "date": evidence.citation.date or "n.d.",
86
  "url": evidence.citation.url,
87
  }
88
+
89
+
90
+ def validate_markdown_citations(
91
+ markdown_report: str, evidence: list["Evidence"]
92
+ ) -> tuple[str, int]:
93
+ """Validate citations in a markdown report against collected evidence.
94
+
95
+ This function validates citations in markdown format (e.g., [1], [2]) by:
96
+ 1. Extracting URLs from the references section
97
+ 2. Matching them against Evidence objects
98
+ 3. Removing invalid citations from the report
99
+
100
+ Note:
101
+ This is a basic validation. For full validation, use ResearchReport
102
+ objects with validate_references().
103
+
104
+ Args:
105
+ markdown_report: The markdown report string with citations
106
+ evidence: List of Evidence objects collected during research
107
+
108
+ Returns:
109
+ Tuple of (validated_markdown, removed_count)
110
+ """
111
+ import re
112
+
113
+ # Build set of valid URLs from evidence
114
+ valid_urls = {e.citation.url for e in evidence}
115
+ valid_urls_lower = {url.lower() for url in valid_urls}
116
+
117
+ # Extract references section (everything after "## References" or "References:")
118
+ ref_section_pattern = r"(?i)(?:##\s*)?References:?\s*\n(.*?)(?=\n##|\Z)"
119
+ ref_match = re.search(ref_section_pattern, markdown_report, re.DOTALL)
120
+
121
+ if not ref_match:
122
+ # No references section found, return as-is
123
+ return markdown_report, 0
124
+
125
+ ref_section = ref_match.group(1)
126
+ ref_lines = ref_section.strip().split("\n")
127
+
128
+ # Parse references: [1] https://example.com or [1] https://example.com Title
129
+ valid_refs = []
130
+ removed_count = 0
131
+
132
+ for ref_line in ref_lines:
133
+ stripped_line = ref_line.strip()
134
+ if not stripped_line:
135
+ continue
136
+
137
+ # Extract URL from reference line
138
+ # Pattern: [N] URL or [N] URL Title
139
+ url_match = re.search(r"https?://[^\s\)]+", stripped_line)
140
+ if url_match:
141
+ url = url_match.group(0).rstrip(".,;")
142
+ url_lower = url.lower()
143
+
144
+ # Check if URL is valid
145
+ if url in valid_urls or url_lower in valid_urls_lower:
146
+ valid_refs.append(stripped_line)
147
+ else:
148
+ removed_count += 1
149
+ logger.warning(
150
+ f"Removed invalid citation from markdown: {url[:80]}"
151
+ + ("..." if len(url) > 80 else "")
152
+ )
153
+ else:
154
+ # No URL found, keep the line (might be formatted differently)
155
+ valid_refs.append(stripped_line)
156
+
157
+ # Rebuild references section
158
+ if valid_refs:
159
+ new_ref_section = "\n".join(valid_refs)
160
+ # Replace the old references section
161
+ validated_markdown = (
162
+ markdown_report[: ref_match.start(1)]
163
+ + new_ref_section
164
+ + markdown_report[ref_match.end(1) :]
165
+ )
166
+ else:
167
+ # No valid references, remove the entire section
168
+ validated_markdown = (
169
+ markdown_report[: ref_match.start()] + markdown_report[ref_match.end() :]
170
+ )
171
+
172
+ if removed_count > 0:
173
+ logger.info(
174
+ f"Citation validation removed {removed_count} invalid citations from markdown report. "
175
+ f"{len(valid_refs)} valid citations remain."
176
+ )
177
+
178
+ return validated_markdown, removed_count
src/utils/config.py CHANGED
@@ -41,15 +41,65 @@ class Settings(BaseSettings):
41
  default="all-MiniLM-L6-v2",
42
  description="Local sentence-transformers model (used by EmbeddingService)",
43
  )
44
 
45
  # PubMed Configuration
46
  ncbi_api_key: str | None = Field(
47
  default=None, description="NCBI API key for higher rate limits"
48
  )
49
 
50
  # Agent Configuration
51
  max_iterations: int = Field(default=10, ge=1, le=50)
52
  search_timeout: int = Field(default=30, description="Seconds to wait for search")
53
 
54
  # Logging
55
  log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
@@ -58,6 +108,34 @@ class Settings(BaseSettings):
58
  modal_token_id: str | None = Field(default=None, description="Modal token ID")
59
  modal_token_secret: str | None = Field(default=None, description="Modal token secret")
60
  chroma_db_path: str = Field(default="./chroma_db", description="ChromaDB storage path")
61
 
62
  @property
63
  def modal_available(self) -> bool:
@@ -102,6 +180,26 @@ class Settings(BaseSettings):
102
  """Check if any LLM API key is available."""
103
  return self.has_openai_key or self.has_anthropic_key
104
 
105
 
106
  def get_settings() -> Settings:
107
  """Factory function to get settings (allows mocking in tests)."""
 
41
  default="all-MiniLM-L6-v2",
42
  description="Local sentence-transformers model (used by EmbeddingService)",
43
  )
44
+ embedding_provider: Literal["openai", "local", "huggingface"] = Field(
45
+ default="local",
46
+ description="Embedding provider to use",
47
+ )
48
+ huggingface_embedding_model: str = Field(
49
+ default="sentence-transformers/all-MiniLM-L6-v2",
50
+ description="HuggingFace embedding model ID",
51
+ )
52
+
53
+ # HuggingFace Configuration
54
+ huggingface_api_key: str | None = Field(
55
+ default=None, description="HuggingFace API token (HF_TOKEN or HUGGINGFACE_API_KEY)"
56
+ )
57
+ huggingface_model: str = Field(
58
+ default="meta-llama/Llama-3.1-8B-Instruct",
59
+ description="Default HuggingFace model ID for inference",
60
+ )
61
 
62
  # PubMed Configuration
63
  ncbi_api_key: str | None = Field(
64
  default=None, description="NCBI API key for higher rate limits"
65
  )
66
 
67
+ # Web Search Configuration
68
+ web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo"] = Field(
69
+ default="duckduckgo",
70
+ description="Web search provider to use",
71
+ )
72
+ serper_api_key: str | None = Field(default=None, description="Serper API key for Google search")
73
+ searchxng_host: str | None = Field(default=None, description="SearchXNG host URL")
74
+ brave_api_key: str | None = Field(default=None, description="Brave Search API key")
75
+ tavily_api_key: str | None = Field(default=None, description="Tavily API key")
76
+
77
  # Agent Configuration
78
  max_iterations: int = Field(default=10, ge=1, le=50)
79
  search_timeout: int = Field(default=30, description="Seconds to wait for search")
80
+ use_graph_execution: bool = Field(
81
+ default=False, description="Use graph-based execution for research flows"
82
+ )
83
+
84
+ # Budget & Rate Limiting Configuration
85
+ default_token_limit: int = Field(
86
+ default=100000,
87
+ ge=1000,
88
+ le=1000000,
89
+ description="Default token budget per research loop",
90
+ )
91
+ default_time_limit_minutes: int = Field(
92
+ default=10,
93
+ ge=1,
94
+ le=120,
95
+ description="Default time limit per research loop (minutes)",
96
+ )
97
+ default_iterations_limit: int = Field(
98
+ default=10,
99
+ ge=1,
100
+ le=50,
101
+ description="Default iterations limit per research loop",
102
+ )
103
 
104
  # Logging
105
  log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
 
108
  modal_token_id: str | None = Field(default=None, description="Modal token ID")
109
  modal_token_secret: str | None = Field(default=None, description="Modal token secret")
110
  chroma_db_path: str = Field(default="./chroma_db", description="ChromaDB storage path")
111
+ chroma_db_persist: bool = Field(
112
+ default=True,
113
+ description="Whether to persist ChromaDB to disk",
114
+ )
115
+ chroma_db_host: str | None = Field(
116
+ default=None,
117
+ description="ChromaDB server host (for remote ChromaDB)",
118
+ )
119
+ chroma_db_port: int | None = Field(
120
+ default=None,
121
+ description="ChromaDB server port (for remote ChromaDB)",
122
+ )
123
+
124
+ # RAG Service Configuration
125
+ rag_collection_name: str = Field(
126
+ default="deepcritical_evidence",
127
+ description="ChromaDB collection name for RAG",
128
+ )
129
+ rag_similarity_top_k: int = Field(
130
+ default=5,
131
+ ge=1,
132
+ le=50,
133
+ description="Number of top results to retrieve from RAG",
134
+ )
135
+ rag_auto_ingest: bool = Field(
136
+ default=True,
137
+ description="Automatically ingest evidence into RAG",
138
+ )
139
 
140
  @property
141
  def modal_available(self) -> bool:
 
180
  """Check if any LLM API key is available."""
181
  return self.has_openai_key or self.has_anthropic_key
182
 
183
+ @property
184
+ def has_huggingface_key(self) -> bool:
185
+ """Check if HuggingFace API key is available."""
186
+ return bool(self.huggingface_api_key)
187
+
188
+ @property
189
+ def web_search_available(self) -> bool:
190
+ """Check if web search is available (either no-key provider or API key present)."""
191
+ if self.web_search_provider == "duckduckgo":
192
+ return True # No API key required
193
+ if self.web_search_provider == "serper":
194
+ return bool(self.serper_api_key)
195
+ if self.web_search_provider == "searchxng":
196
+ return bool(self.searchxng_host)
197
+ if self.web_search_provider == "brave":
198
+ return bool(self.brave_api_key)
199
+ if self.web_search_provider == "tavily":
200
+ return bool(self.tavily_api_key)
201
+ return False
202
+
203
 
204
  def get_settings() -> Settings:
205
  """Factory function to get settings (allows mocking in tests)."""
src/utils/models.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, ClassVar, Literal
6
  from pydantic import BaseModel, Field
7
 
8
  # Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
9
- SourceName = Literal["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint"]
10
 
11
 
12
  class Citation(BaseModel):
@@ -303,3 +303,269 @@ class OrchestratorConfig(BaseModel):
303
  max_iterations: int = Field(default=10, ge=1, le=20)
304
  max_results_per_tool: int = Field(default=10, ge=1, le=50)
305
  search_timeout: float = Field(default=30.0, ge=5.0, le=120.0)
 
6
  from pydantic import BaseModel, Field
7
 
8
  # Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
9
+ SourceName = Literal["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web"]
10
 
11
 
12
  class Citation(BaseModel):
 
303
  max_iterations: int = Field(default=10, ge=1, le=20)
304
  max_results_per_tool: int = Field(default=10, ge=1, le=50)
305
  search_timeout: float = Field(default=30.0, ge=5.0, le=120.0)
306
+
307
+
308
+ # Models for iterative/deep research patterns
309
+
310
+
311
+ class IterationData(BaseModel):
312
+ """Data for a single iteration of the research loop."""
313
+
314
+ gap: str = Field(description="The gap addressed in the iteration", default="")
315
+ tool_calls: list[str] = Field(description="The tool calls made", default_factory=list)
316
+ findings: list[str] = Field(
317
+ description="The findings collected from tool calls", default_factory=list
318
+ )
319
+ thought: str = Field(
320
+ description="The thinking done to reflect on the success of the iteration and next steps",
321
+ default="",
322
+ )
323
+
324
+ model_config = {"frozen": True}
325
+
326
+
327
+ class Conversation(BaseModel):
328
+ """A conversation between the user and the iterative researcher."""
329
+
330
+ history: list[IterationData] = Field(
331
+ description="The data for each iteration of the research loop",
332
+ default_factory=list,
333
+ )
334
+
335
+ def add_iteration(self, iteration_data: IterationData | None = None) -> None:
336
+ """Add a new iteration to the conversation history."""
337
+ if iteration_data is None:
338
+ iteration_data = IterationData()
339
+ self.history.append(iteration_data)
340
+
341
+ def set_latest_gap(self, gap: str) -> None:
342
+ """Set the gap for the latest iteration."""
343
+ if not self.history:
344
+ self.add_iteration()
345
+ # Use model_copy() since IterationData is frozen
346
+ self.history[-1] = self.history[-1].model_copy(update={"gap": gap})
347
+
348
+ def set_latest_tool_calls(self, tool_calls: list[str]) -> None:
349
+ """Set the tool calls for the latest iteration."""
350
+ if not self.history:
351
+ self.add_iteration()
352
+ # Use model_copy() since IterationData is frozen
353
+ self.history[-1] = self.history[-1].model_copy(update={"tool_calls": tool_calls})
354
+
355
+ def set_latest_findings(self, findings: list[str]) -> None:
356
+ """Set the findings for the latest iteration."""
357
+ if not self.history:
358
+ self.add_iteration()
359
+ # Use model_copy() since IterationData is frozen
360
+ self.history[-1] = self.history[-1].model_copy(update={"findings": findings})
361
+
362
+ def set_latest_thought(self, thought: str) -> None:
363
+ """Set the thought for the latest iteration."""
364
+ if not self.history:
365
+ self.add_iteration()
366
+ # Use model_copy() since IterationData is frozen
367
+ self.history[-1] = self.history[-1].model_copy(update={"thought": thought})
368
+
369
+ def get_latest_gap(self) -> str:
370
+ """Get the gap from the latest iteration."""
371
+ if not self.history:
372
+ return ""
373
+ return self.history[-1].gap
374
+
375
+ def get_latest_tool_calls(self) -> list[str]:
376
+ """Get the tool calls from the latest iteration."""
377
+ if not self.history:
378
+ return []
379
+ return self.history[-1].tool_calls
380
+
381
+ def get_latest_findings(self) -> list[str]:
382
+ """Get the findings from the latest iteration."""
383
+ if not self.history:
384
+ return []
385
+ return self.history[-1].findings
386
+
387
+ def get_latest_thought(self) -> str:
388
+ """Get the thought from the latest iteration."""
389
+ if not self.history:
390
+ return ""
391
+ return self.history[-1].thought
392
+
393
+ def get_all_findings(self) -> list[str]:
394
+ """Get all findings from all iterations."""
395
+ return [finding for iteration_data in self.history for finding in iteration_data.findings]
396
+
397
+ def compile_conversation_history(self) -> str:
398
+ """Compile the conversation history into a string."""
399
+ conversation = ""
400
+ for iteration_num, iteration_data in enumerate(self.history):
401
+ conversation += f"[ITERATION {iteration_num + 1}]\n\n"
402
+ if iteration_data.thought:
403
+ conversation += f"{self.get_thought_string(iteration_num)}\n\n"
404
+ if iteration_data.gap:
405
+ conversation += f"{self.get_task_string(iteration_num)}\n\n"
406
+ if iteration_data.tool_calls:
407
+ conversation += f"{self.get_action_string(iteration_num)}\n\n"
408
+ if iteration_data.findings:
409
+ conversation += f"{self.get_findings_string(iteration_num)}\n\n"
410
+
411
+ return conversation
412
+
413
+ def get_task_string(self, iteration_num: int) -> str:
414
+ """Get the task for the specified iteration."""
415
+ if iteration_num < len(self.history) and self.history[iteration_num].gap:
416
+ return (
417
+ f"<task>\nAddress this knowledge gap: "
418
+ f"{self.history[iteration_num].gap}\n</task>"
419
+ )
420
+ return ""
421
+
422
+ def get_action_string(self, iteration_num: int) -> str:
423
+ """Get the action for the specified iteration."""
424
+ if iteration_num < len(self.history) and self.history[iteration_num].tool_calls:
425
+ joined_calls = "\n".join(self.history[iteration_num].tool_calls)
426
+ return (
427
+ "<action>\nCalling the following tools to address the knowledge gap:\n"
428
+ f"{joined_calls}\n</action>"
429
+ )
430
+ return ""
431
+
432
+ def get_findings_string(self, iteration_num: int) -> str:
433
+ """Get the findings for the specified iteration."""
434
+ if iteration_num < len(self.history) and self.history[iteration_num].findings:
435
+ joined_findings = "\n\n".join(self.history[iteration_num].findings)
436
+ return f"<findings>\n{joined_findings}\n</findings>"
437
+ return ""
438
+
439
+ def get_thought_string(self, iteration_num: int) -> str:
440
+ """Get the thought for the specified iteration."""
441
+ if iteration_num < len(self.history) and self.history[iteration_num].thought:
442
+ return f"<thought>\n{self.history[iteration_num].thought}\n</thought>"
443
+ return ""
444
+
445
+ def latest_task_string(self) -> str:
446
+ """Get the latest task."""
447
+ if not self.history:
448
+ return ""
449
+ return self.get_task_string(len(self.history) - 1)
450
+
451
+ def latest_action_string(self) -> str:
452
+ """Get the latest action."""
453
+ if not self.history:
454
+ return ""
455
+ return self.get_action_string(len(self.history) - 1)
456
+
457
+ def latest_findings_string(self) -> str:
458
+ """Get the latest findings."""
459
+ if not self.history:
460
+ return ""
461
+ return self.get_findings_string(len(self.history) - 1)
462
+
463
+ def latest_thought_string(self) -> str:
464
+ """Get the latest thought."""
465
+ if not self.history:
466
+ return ""
467
+ return self.get_thought_string(len(self.history) - 1)
468
+
469
+
470
+ class ReportPlanSection(BaseModel):
471
+ """A section of the report that needs to be written."""
472
+
473
+ title: str = Field(description="The title of the section")
474
+ key_question: str = Field(description="The key question to be addressed in the section")
475
+
476
+ model_config = {"frozen": True}
477
+
478
+
479
+ class ReportPlan(BaseModel):
480
+ """Output from the Report Planner Agent."""
481
+
482
+ background_context: str = Field(
483
+ description="A summary of supporting context that can be passed onto the research agents"
484
+ )
485
+ report_outline: list[ReportPlanSection] = Field(
486
+ description="List of sections that need to be written in the report"
487
+ )
488
+ report_title: str = Field(description="The title of the report")
489
+
490
+ model_config = {"frozen": True}
491
+
492
+
493
+ class KnowledgeGapOutput(BaseModel):
494
+ """Output from the Knowledge Gap Agent."""
495
+
496
+ research_complete: bool = Field(
497
+ description="Whether the research and findings are complete enough to end the research loop"
498
+ )
499
+ outstanding_gaps: list[str] = Field(
500
+ description="List of knowledge gaps that still need to be addressed"
501
+ )
502
+
503
+ model_config = {"frozen": True}
504
+
505
+
506
+ class AgentTask(BaseModel):
507
+ """A task for a specific agent to address knowledge gaps."""
508
+
509
+ gap: str | None = Field(description="The knowledge gap being addressed", default=None)
510
+ agent: str = Field(description="The name of the agent to use")
511
+ query: str = Field(description="The specific query for the agent")
512
+ entity_website: str | None = Field(
513
+ description="The website of the entity being researched, if known",
514
+ default=None,
515
+ )
516
+
517
+ model_config = {"frozen": True}
518
+
519
+
520
+ class AgentSelectionPlan(BaseModel):
521
+ """Plan for which agents to use for knowledge gaps."""
522
+
523
+ tasks: list[AgentTask] = Field(description="List of agent tasks to address knowledge gaps")
524
+
525
+ model_config = {"frozen": True}
526
+
527
+
528
+ class ReportDraftSection(BaseModel):
529
+ """A section of the report that needs to be written."""
530
+
531
+ section_title: str = Field(description="The title of the section")
532
+ section_content: str = Field(description="The content of the section")
533
+
534
+ model_config = {"frozen": True}
535
+
536
+
537
+ class ReportDraft(BaseModel):
538
+ """Output from the Report Planner Agent."""
539
+
540
+ sections: list[ReportDraftSection] = Field(
541
+ description="List of sections that are in the report"
542
+ )
543
+
544
+ model_config = {"frozen": True}
545
+
546
+
547
+ class ToolAgentOutput(BaseModel):
548
+ """Standard output for all tool agents."""
549
+
550
+ output: str = Field(description="The output from the tool agent")
551
+ sources: list[str] = Field(description="List of source URLs", default_factory=list)
552
+
553
+ model_config = {"frozen": True}
554
+
555
+
556
+ class ParsedQuery(BaseModel):
557
+ """Parsed and improved user query with research mode detection."""
558
+
559
+ original_query: str = Field(description="The original user query")
560
+ improved_query: str = Field(description="Improved/refined query")
561
+ research_mode: Literal["iterative", "deep"] = Field(description="Detected research mode")
562
+ key_entities: list[str] = Field(
563
+ default_factory=list,
564
+ description="Key entities extracted from query",
565
+ )
566
+ research_questions: list[str] = Field(
567
+ default_factory=list,
568
+ description="Specific research questions extracted",
569
+ )
570
+
571
+ model_config = {"frozen": True}
tests/integration/test_deep_research.py ADDED
@@ -0,0 +1,352 @@
1
+ """Integration tests for deep research flow.
2
+
3
+ Tests the complete deep research pattern: plan → parallel loops → synthesis.
4
+ """
5
+
6
+ from unittest.mock import AsyncMock, patch
7
+
8
+ import pytest
9
+
10
+ from src.middleware.state_machine import init_workflow_state
11
+ from src.orchestrator.research_flow import DeepResearchFlow
12
+ from src.utils.models import ReportPlan, ReportPlanSection
13
+
14
+
15
+ @pytest.mark.integration
16
+ class TestDeepResearchFlow:
17
+ """Integration tests for DeepResearchFlow."""
18
+
19
+ @pytest.mark.asyncio
20
+ async def test_deep_research_creates_plan(self) -> None:
21
+ """Test that deep research creates a report plan."""
22
+ # Initialize workflow state
23
+ init_workflow_state()
24
+
25
+ flow = DeepResearchFlow(
26
+ max_iterations=2,
27
+ max_time_minutes=5,
28
+ verbose=False,
29
+ use_graph=False,
30
+ )
31
+
32
+ # Mock the planner agent to return a simple plan
33
+ mock_plan = ReportPlan(
34
+ background_context="Test background context",
35
+ report_outline=[
36
+ ReportPlanSection(
37
+ title="Section 1",
38
+ key_question="What is the first question?",
39
+ ),
40
+ ReportPlanSection(
41
+ title="Section 2",
42
+ key_question="What is the second question?",
43
+ ),
44
+ ],
45
+ report_title="Test Report",
46
+ )
47
+
48
+ flow.planner_agent.run = AsyncMock(return_value=mock_plan)
49
+
50
+ # Mock the iterative research flows to return simple drafts
51
+ async def mock_iterative_run(query: str, **kwargs: dict) -> str:
52
+ return f"# Draft for: {query}\n\nThis is a test draft."
53
+
54
+ # Mock the long writer to return a simple report
55
+ flow.long_writer_agent.write_report = AsyncMock(
56
+ return_value="# Test Report\n\n## Section 1\n\nDraft 1\n\n## Section 2\n\nDraft 2"
57
+ )
58
+
59
+ # We can't easily mock the IterativeResearchFlow.run() without more setup
60
+ # So we'll test the plan creation separately
61
+ plan = await flow._build_report_plan("Test query")
62
+
63
+ assert isinstance(plan, ReportPlan)
64
+ assert plan.report_title == "Test Report"
65
+ assert len(plan.report_outline) == 2
66
+ assert plan.report_outline[0].title == "Section 1"
67
+
68
+ @pytest.mark.asyncio
69
+ async def test_deep_research_parallel_loops_state_synchronization(self) -> None:
70
+ """Test that parallel loops properly synchronize state."""
71
+ # Initialize workflow state
72
+ state = init_workflow_state()
73
+
74
+ flow = DeepResearchFlow(
75
+ max_iterations=1,
76
+ max_time_minutes=2,
77
+ verbose=False,
78
+ use_graph=False,
79
+ )
80
+
81
+ # Create a simple report plan
82
+ report_plan = ReportPlan(
83
+ background_context="Test background",
84
+ report_outline=[
85
+ ReportPlanSection(
86
+ title="Section 1",
87
+ key_question="Question 1?",
88
+ ),
89
+ ReportPlanSection(
90
+ title="Section 2",
91
+ key_question="Question 2?",
92
+ ),
93
+ ],
94
+ report_title="Test Report",
95
+ )
96
+
97
+ # Mock iterative research flows to add evidence to state
98
+ from src.utils.models import Citation, Evidence
99
+
100
+ async def mock_iterative_run(query: str, **kwargs: dict) -> str:
101
+ # Add evidence to state to test synchronization
102
+ ev = Evidence(
103
+ content=f"Evidence for {query}",
104
+ citation=Citation(
105
+ source="pubmed",
106
+ title=f"Title for {query}",
107
+ url=f"https://example.com/{query.replace('?', '').replace(' ', '_')}",
108
+ date="2024-01-01",
109
+ ),
110
+ )
111
+ state.add_evidence([ev])
112
+ return f"# Draft: {query}\n\nTest content."
113
+
114
+ # Patch IterativeResearchFlow.run
115
+ with patch(
116
+ "src.orchestrator.research_flow.IterativeResearchFlow.run",
117
+ side_effect=mock_iterative_run,
118
+ ):
119
+ section_drafts = await flow._run_research_loops(report_plan)
120
+
121
+ # Verify parallel execution
122
+ assert len(section_drafts) == 2
123
+ assert "Question 1" in section_drafts[0]
124
+ assert "Question 2" in section_drafts[1]
125
+
126
+ # Verify state has evidence from both sections
127
+ # Note: In real execution, evidence would be synced via WorkflowManager
128
+ # This test verifies the structure works
129
+
130
+ @pytest.mark.asyncio
131
+ async def test_deep_research_synthesizes_final_report(self) -> None:
132
+ """Test that deep research synthesizes final report from section drafts."""
133
+ flow = DeepResearchFlow(
134
+ max_iterations=1,
135
+ max_time_minutes=2,
136
+ verbose=False,
137
+ use_graph=False,
138
+ use_long_writer=True,
139
+ )
140
+
141
+ # Create report plan
142
+ report_plan = ReportPlan(
143
+ background_context="Test background",
144
+ report_outline=[
145
+ ReportPlanSection(
146
+ title="Introduction",
147
+ key_question="What is the topic?",
148
+ ),
149
+ ReportPlanSection(
150
+ title="Conclusion",
151
+ key_question="What are the conclusions?",
152
+ ),
153
+ ],
154
+ report_title="Test Report",
155
+ )
156
+
157
+ # Create section drafts
158
+ section_drafts = [
159
+ "# Introduction\n\nThis is the introduction section.",
160
+ "# Conclusion\n\nThis is the conclusion section.",
161
+ ]
162
+
163
+ # Mock long writer
164
+ flow.long_writer_agent.write_report = AsyncMock(
165
+ return_value="# Test Report\n\n## Introduction\n\nContent\n\n## Conclusion\n\nContent"
166
+ )
167
+
168
+ final_report = await flow._create_final_report("Test query", report_plan, section_drafts)
169
+
170
+ assert isinstance(final_report, str)
171
+ assert "Test Report" in final_report
172
+ # Verify long writer was called with correct parameters
173
+ flow.long_writer_agent.write_report.assert_called_once()
174
+ call_args = flow.long_writer_agent.write_report.call_args
175
+ assert call_args.kwargs["original_query"] == "Test query"
176
+ assert call_args.kwargs["report_title"] == "Test Report"
177
+ assert len(call_args.kwargs["report_draft"].sections) == 2
178
+
179
+ @pytest.mark.asyncio
180
+ async def test_deep_research_agent_chains_full_flow(self) -> None:
181
+ """Test full deep research flow with agent chains (mocked)."""
182
+ # Initialize workflow state
183
+ init_workflow_state()
184
+
185
+ flow = DeepResearchFlow(
186
+ max_iterations=1,
187
+ max_time_minutes=2,
188
+ verbose=False,
189
+ use_graph=False,
190
+ )
191
+
192
+ # Mock all agents
193
+ mock_plan = ReportPlan(
194
+ background_context="Background",
195
+ report_outline=[
196
+ ReportPlanSection(
197
+ title="Section 1",
198
+ key_question="Question 1?",
199
+ ),
200
+ ],
201
+ report_title="Test Report",
202
+ )
203
+
204
+ flow.planner_agent.run = AsyncMock(return_value=mock_plan)
205
+
206
+ # Mock iterative research
207
+ async def mock_iterative_run(query: str, **kwargs: dict) -> str:
208
+ return f"# Draft\n\nAnswer to {query}"
209
+
210
+ with patch(
211
+ "src.orchestrator.research_flow.IterativeResearchFlow.run",
212
+ side_effect=mock_iterative_run,
213
+ ):
214
+ flow.long_writer_agent.write_report = AsyncMock(
215
+ return_value="# Test Report\n\n## Section 1\n\nDraft content"
216
+ )
217
+
218
+ # Run the full flow
219
+ result = await flow._run_with_chains("Test query")
220
+
221
+ assert isinstance(result, str)
222
+ assert "Test Report" in result
223
+ flow.planner_agent.run.assert_called_once()
224
+ flow.long_writer_agent.write_report.assert_called_once()
225
+
226
+ @pytest.mark.asyncio
227
+ async def test_deep_research_handles_multiple_sections(self) -> None:
228
+ """Test that deep research handles multiple sections correctly."""
229
+ flow = DeepResearchFlow(
230
+ max_iterations=1,
231
+ max_time_minutes=2,
232
+ verbose=False,
233
+ use_graph=False,
234
+ )
235
+
236
+ # Create plan with multiple sections
237
+ report_plan = ReportPlan(
238
+ background_context="Background",
239
+ report_outline=[
240
+ ReportPlanSection(
241
+ title=f"Section {i}",
242
+ key_question=f"Question {i}?",
243
+ )
244
+ for i in range(5) # 5 sections
245
+ ],
246
+ report_title="Multi-Section Report",
247
+ )
248
+
249
+ # Mock iterative research to return unique drafts
250
+ async def mock_iterative_run(query: str, **kwargs: dict) -> str:
251
+ section_num = query.split()[-1].replace("?", "")
252
+ return f"# Section {section_num} Draft\n\nContent for section {section_num}"
253
+
254
+ with patch(
255
+ "src.orchestrator.research_flow.IterativeResearchFlow.run",
256
+ side_effect=mock_iterative_run,
257
+ ):
258
+ section_drafts = await flow._run_research_loops(report_plan)
259
+
260
+ # Verify all sections were processed
261
+ assert len(section_drafts) == 5
262
+ for i, draft in enumerate(section_drafts):
263
+ assert f"Section {i}" in draft or f"section {i}" in draft.lower()
264
+
265
+ @pytest.mark.asyncio
266
+ async def test_deep_research_workflow_manager_integration(self) -> None:
267
+ """Test that deep research properly uses WorkflowManager."""
268
+
269
+ # Initialize workflow state
270
+ init_workflow_state()
271
+
272
+ flow = DeepResearchFlow(
273
+ max_iterations=1,
274
+ max_time_minutes=2,
275
+ verbose=False,
276
+ use_graph=False,
277
+ )
278
+
279
+ # Create report plan
280
+ report_plan = ReportPlan(
281
+ background_context="Background",
282
+ report_outline=[
283
+ ReportPlanSection(
284
+ title="Section 1",
285
+ key_question="Question 1?",
286
+ ),
287
+ ReportPlanSection(
288
+ title="Section 2",
289
+ key_question="Question 2?",
290
+ ),
291
+ ],
292
+ report_title="Test Report",
293
+ )
294
+
295
+ # Mock iterative research
296
+ async def mock_iterative_run(query: str, **kwargs: dict) -> str:
297
+ return f"# Draft: {query}"
298
+
299
+ with patch(
300
+ "src.orchestrator.research_flow.IterativeResearchFlow.run",
301
+ side_effect=mock_iterative_run,
302
+ ):
303
+ section_drafts = await flow._run_research_loops(report_plan)
304
+
305
+ # Verify WorkflowManager was used (section_drafts should be returned)
306
+ assert len(section_drafts) == 2
307
+ # Each draft should be a string
308
+ assert all(isinstance(draft, str) for draft in section_drafts)
309
+
310
+ @pytest.mark.asyncio
311
+ async def test_deep_research_state_initialization(self) -> None:
312
+ """Test that deep research properly initializes workflow state."""
313
+ flow = DeepResearchFlow(
314
+ max_iterations=1,
315
+ max_time_minutes=2,
316
+ verbose=False,
317
+ use_graph=False,
318
+ )
319
+
320
+ # Mock the planner
321
+ mock_plan = ReportPlan(
322
+ background_context="Background",
323
+ report_outline=[
324
+ ReportPlanSection(
325
+ title="Section 1",
326
+ key_question="Question 1?",
327
+ ),
328
+ ],
329
+ report_title="Test Report",
330
+ )
331
+
332
+ flow.planner_agent.run = AsyncMock(return_value=mock_plan)
333
+
334
+ # Mock iterative research
335
+ async def mock_iterative_run(query: str, **kwargs: dict) -> str:
336
+ return "# Draft"
337
+
338
+ with patch(
339
+ "src.orchestrator.research_flow.IterativeResearchFlow.run",
340
+ side_effect=mock_iterative_run,
341
+ ):
342
+ flow.long_writer_agent.write_report = AsyncMock(return_value="# Test Report\n\nContent")
343
+
344
+ # Run with chains - should initialize state
345
+ # Note: _run_with_chains handles missing embedding service gracefully
346
+ await flow._run_with_chains("Test query")
347
+
348
+ # Verify state was initialized (get_workflow_state should not raise)
349
+ from src.middleware.state_machine import get_workflow_state
350
+
351
+ state = get_workflow_state()
352
+ assert state is not None
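These tests drive DeepResearchFlow through its private helpers (`_build_report_plan`, `_run_research_loops`, `_create_final_report`). End to end, the flow would be invoked roughly like this — a sketch that assumes a public async `run()` entry point, which the diff does not show:

```python
import asyncio

from src.middleware.state_machine import init_workflow_state
from src.orchestrator.research_flow import DeepResearchFlow

async def main() -> None:
    init_workflow_state()
    flow = DeepResearchFlow(max_iterations=2, max_time_minutes=5, verbose=True, use_graph=False)
    # run() is assumed here; the tests above exercise the private helpers directly.
    report = await flow.run("What existing drugs might help treat long COVID fatigue?")
    print(report)

asyncio.run(main())
```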
tests/integration/test_middleware_integration.py ADDED
@@ -0,0 +1,245 @@
1
+ """Integration tests for middleware components.
2
+
3
+ Tests the interaction between WorkflowState, WorkflowManager, and BudgetTracker.
4
+ """
5
+
6
+ import pytest
7
+
8
+ from src.middleware.budget_tracker import BudgetTracker
9
+ from src.middleware.state_machine import init_workflow_state
10
+ from src.middleware.workflow_manager import WorkflowManager
11
+ from src.utils.models import Citation, Evidence
12
+
13
+
14
+ @pytest.mark.integration
15
+ class TestMiddlewareIntegration:
16
+ """Integration tests for middleware components."""
17
+
18
+ @pytest.mark.asyncio
19
+ async def test_state_manager_integration(self) -> None:
20
+ """Test WorkflowState and WorkflowManager integration."""
21
+ # Initialize state
22
+ state = init_workflow_state()
23
+ manager = WorkflowManager()
24
+
25
+ # Create a loop
26
+ loop = await manager.add_loop("test_loop", "Test query")
27
+
28
+ # Add evidence to loop
29
+ ev = Evidence(
30
+ content="Test evidence",
31
+ citation=Citation(
32
+ source="pubmed", title="Test Title", url="https://example.com/1", date="2024-01-01"
33
+ ),
34
+ )
35
+ await manager.add_loop_evidence("test_loop", [ev])
36
+
37
+ # Sync to global state
38
+ await manager.sync_loop_evidence_to_state("test_loop")
39
+
40
+ # Verify state has evidence
41
+ assert len(state.evidence) == 1
42
+ assert state.evidence[0].content == "Test evidence"
43
+
44
+ # Verify loop still has evidence
45
+ loop = await manager.get_loop("test_loop")
46
+ assert loop is not None
47
+ assert len(loop.evidence) == 1
48
+
49
+ @pytest.mark.asyncio
50
+ async def test_budget_tracker_with_workflow_manager(self) -> None:
51
+ """Test BudgetTracker integration with WorkflowManager."""
52
+ manager = WorkflowManager()
53
+ tracker = BudgetTracker()
54
+
55
+ # Create loop and budget
56
+ await manager.add_loop("budget_loop", "Test query")
57
+ tracker.create_budget("budget_loop", tokens_limit=1000, time_limit_seconds=60.0)
58
+ tracker.start_timer("budget_loop")
59
+
60
+ # Simulate some work
61
+ tracker.add_tokens("budget_loop", 500)
62
+ await manager.increment_loop_iteration("budget_loop")
63
+ tracker.increment_iteration("budget_loop")
64
+
65
+ # Check budget
66
+ can_continue = tracker.can_continue("budget_loop")
67
+ assert can_continue is True
68
+
69
+ # Exceed budget
70
+ tracker.add_tokens("budget_loop", 600) # Total: 1100 > 1000
71
+ can_continue = tracker.can_continue("budget_loop")
72
+ assert can_continue is False
73
+
74
+ # Update loop status based on budget
75
+ if not can_continue:
76
+ await manager.update_loop_status("budget_loop", "cancelled")
77
+
78
+ loop = await manager.get_loop("budget_loop")
79
+ assert loop is not None
80
+ assert loop.status == "cancelled"
81
+
82
+ @pytest.mark.asyncio
83
+ async def test_parallel_loops_with_budget_tracking(self) -> None:
84
+ """Test parallel loops with budget tracking."""
85
+
86
+ async def mock_research_loop(config: dict) -> str:
87
+ """Mock research loop function."""
88
+ loop_id = config.get("loop_id", "unknown")
89
+ tracker = BudgetTracker()
90
+ manager = WorkflowManager()
91
+
92
+ # Get or create budget
93
+ budget = tracker.get_budget(loop_id)
94
+ if not budget:
95
+ tracker.create_budget(loop_id, tokens_limit=500, time_limit_seconds=10.0)
96
+ tracker.start_timer(loop_id)
97
+
98
+ # Simulate work
99
+ tracker.add_tokens(loop_id, 100)
100
+ await manager.increment_loop_iteration(loop_id)
101
+ tracker.increment_iteration(loop_id)
102
+
103
+ # Check if can continue
104
+ if not tracker.can_continue(loop_id):
105
+ await manager.update_loop_status(loop_id, "cancelled")
106
+ return f"Cancelled: {loop_id}"
107
+
108
+ await manager.update_loop_status(loop_id, "completed")
109
+ return f"Completed: {loop_id}"
110
+
111
+ manager = WorkflowManager()
112
+ tracker = BudgetTracker()
113
+
114
+ # Create budgets for all loops
115
+ configs = [
116
+ {"loop_id": "loop1", "query": "Query 1"},
117
+ {"loop_id": "loop2", "query": "Query 2"},
118
+ {"loop_id": "loop3", "query": "Query 3"},
119
+ ]
120
+
121
+ for config in configs:
122
+ loop_id = config["loop_id"]
123
+ await manager.add_loop(loop_id, config["query"])
124
+ tracker.create_budget(loop_id, tokens_limit=500, time_limit_seconds=10.0)
125
+ tracker.start_timer(loop_id)
126
+
127
+ # Run loops in parallel
128
+ results = await manager.run_loops_parallel(configs, mock_research_loop)
129
+
130
+ # Verify all loops completed
131
+ assert len(results) == 3
132
+ for config in configs:
133
+ loop_id = config["loop_id"]
134
+ loop = await manager.get_loop(loop_id)
135
+ assert loop is not None
136
+ assert loop.status in ("completed", "cancelled")
137
+
138
+ @pytest.mark.asyncio
139
+ async def test_state_conversation_integration(self) -> None:
140
+ """Test WorkflowState conversation integration."""
141
+ state = init_workflow_state()
142
+
143
+ # Add iteration data
144
+ state.conversation.add_iteration()
145
+ state.conversation.set_latest_gap("Knowledge gap 1")
146
+ state.conversation.set_latest_tool_calls(["tool1", "tool2"])
147
+ state.conversation.set_latest_findings(["finding1", "finding2"])
148
+ state.conversation.set_latest_thought("Thought about findings")
149
+
150
+ # Verify conversation history
151
+ assert len(state.conversation.history) == 1
152
+ assert state.conversation.get_latest_gap() == "Knowledge gap 1"
153
+ assert len(state.conversation.get_latest_tool_calls()) == 2
154
+ assert len(state.conversation.get_latest_findings()) == 2
155
+
156
+ # Compile history
157
+ history_str = state.conversation.compile_conversation_history()
158
+ assert "Knowledge gap 1" in history_str
159
+ assert "tool1" in history_str
160
+ assert "finding1" in history_str
161
+ assert "Thought about findings" in history_str
162
+
163
+ @pytest.mark.asyncio
164
+ async def test_multiple_iterations_with_budget(self) -> None:
165
+ """Test multiple iterations with budget enforcement."""
166
+ manager = WorkflowManager()
167
+ tracker = BudgetTracker()
168
+
169
+ loop_id = "iterative_loop"
170
+ await manager.add_loop(loop_id, "Iterative query")
171
+ tracker.create_budget(loop_id, tokens_limit=1000, iterations_limit=5)
172
+ tracker.start_timer(loop_id)
173
+
174
+ # Simulate multiple iterations
175
+ for _ in range(7): # Try 7 iterations, but limit is 5
176
+ tracker.add_tokens(loop_id, 100)
177
+ await manager.increment_loop_iteration(loop_id)
178
+ tracker.increment_iteration(loop_id)
179
+
180
+ can_continue = tracker.can_continue(loop_id)
181
+ if not can_continue:
182
+ await manager.update_loop_status(loop_id, "cancelled")
183
+ break
184
+
185
+ loop = await manager.get_loop(loop_id)
186
+ assert loop is not None
187
+ # Should be cancelled after 5 iterations
188
+ assert loop.status == "cancelled"
189
+ assert loop.iteration_count == 5
190
+
191
+ @pytest.mark.asyncio
192
+ async def test_evidence_deduplication_across_loops(self) -> None:
193
+ """Test evidence deduplication when syncing from multiple loops."""
194
+ state = init_workflow_state()
195
+ manager = WorkflowManager()
196
+
197
+ # Create two loops with same evidence
198
+ ev1 = Evidence(
199
+ content="Same content",
200
+ citation=Citation(
201
+ source="pubmed", title="Title", url="https://example.com/1", date="2024"
202
+ ),
203
+ )
204
+ ev2 = Evidence(
205
+ content="Different content",
206
+ citation=Citation(
207
+ source="pubmed", title="Title 2", url="https://example.com/2", date="2024"
208
+ ),
209
+ )
210
+
211
+ # Add to loop1
212
+ await manager.add_loop("loop1", "Query 1")
213
+ await manager.add_loop_evidence("loop1", [ev1, ev2])
214
+ await manager.sync_loop_evidence_to_state("loop1")
215
+
216
+ # Add duplicate to loop2
217
+ await manager.add_loop("loop2", "Query 2")
218
+ ev1_duplicate = Evidence(
219
+ content="Same content (duplicate)",
220
+ citation=Citation(
221
+ source="pubmed", title="Title Duplicate", url="https://example.com/1", date="2024"
222
+ ),
223
+ )
224
+ await manager.add_loop_evidence("loop2", [ev1_duplicate])
225
+ await manager.sync_loop_evidence_to_state("loop2")
226
+
227
+ # State should have only 2 unique items (deduplicated by URL)
228
+ assert len(state.evidence) == 2
229
+
230
+ @pytest.mark.asyncio
231
+ async def test_global_budget_enforcement(self) -> None:
232
+ """Test global budget enforcement across all loops."""
233
+ tracker = BudgetTracker()
234
+ tracker.set_global_budget(tokens_limit=2000, time_limit_seconds=60.0)
235
+
236
+ # Simulate multiple loops consuming global budget
237
+ tracker.add_global_tokens(500) # Loop 1
238
+ tracker.add_global_tokens(600) # Loop 2
239
+ tracker.add_global_tokens(700) # Loop 3
240
+ tracker.add_global_tokens(300) # Loop 4 - pushes the total to 2100, past the limit
241
+
242
+ global_budget = tracker.get_global_budget()
243
+ assert global_budget is not None
244
+ assert global_budget.tokens_used == 2100
245
+ assert global_budget.is_exceeded() is True
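The pattern these tests verify — create a budget, start the timer, and gate every iteration on `can_continue` — distills to a loop like this, using only methods exercised above:

```python
from src.middleware.budget_tracker import BudgetTracker
from src.middleware.workflow_manager import WorkflowManager

async def bounded_loop(loop_id: str, query: str) -> None:
    manager = WorkflowManager()
    tracker = BudgetTracker()

    await manager.add_loop(loop_id, query)
    tracker.create_budget(loop_id, tokens_limit=1000, iterations_limit=5)
    tracker.start_timer(loop_id)

    while tracker.can_continue(loop_id):
        tracker.add_tokens(loop_id, 100)  # tokens spent this iteration
        await manager.increment_loop_iteration(loop_id)
        tracker.increment_iteration(loop_id)

    # Budget exhausted (tokens, time, or iterations) — mark the loop cancelled.
    await manager.update_loop_status(loop_id, "cancelled")
```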
tests/integration/test_parallel_loops_judge.py ADDED
@@ -0,0 +1,396 @@
1
+ """Integration tests for Phase 7: Parallel loops with judge-based completion.
2
+
3
+ These tests verify that WorkflowManager can coordinate parallel research loops
4
+ and use the judge to determine when loops should complete.
5
+ """
6
+
7
+ from unittest.mock import AsyncMock, MagicMock, patch
8
+
9
+ import pytest
10
+
11
+ from src.middleware.workflow_manager import WorkflowManager
12
+ from src.orchestrator.research_flow import IterativeResearchFlow
13
+ from src.utils.models import Citation, Evidence, JudgeAssessment
14
+
15
+
16
+ @pytest.fixture
17
+ def mock_judge_handler():
18
+ """Create a mock judge handler."""
19
+ judge = MagicMock()
20
+ judge.assess = AsyncMock()
21
+ return judge
22
+
23
+
24
+ @pytest.fixture
25
+ def mock_iterative_flow():
26
+ """Create a mock iterative research flow."""
27
+ flow = MagicMock(spec=IterativeResearchFlow)
28
+ flow.run = AsyncMock(return_value="# Test Report\n\nContent here.")
29
+ return flow
30
+
31
+
32
+ @pytest.mark.integration
33
+ @pytest.mark.asyncio
34
+ class TestParallelLoopsWithJudge:
35
+ """Tests for parallel loops with judge-based completion."""
36
+
37
+ async def test_get_loop_evidence(self):
38
+ """get_loop_evidence should return evidence from a loop."""
39
+ manager = WorkflowManager()
40
+ await manager.add_loop("loop1", "Test query")
41
+
42
+ # Add evidence to the loop
43
+ evidence = [
44
+ Evidence(
45
+ content="Test evidence",
46
+ citation=Citation(
47
+ source="rag", # Use valid SourceName
48
+ title="Test",
49
+ url="https://example.com",
50
+ date="2024-01-01",
51
+ authors=[],
52
+ ),
53
+ relevance=0.8,
54
+ )
55
+ ]
56
+ await manager.add_loop_evidence("loop1", evidence)
57
+
58
+ # Retrieve evidence
59
+ retrieved_evidence = await manager.get_loop_evidence("loop1")
60
+ assert len(retrieved_evidence) == 1
61
+ assert retrieved_evidence[0].content == "Test evidence"
62
+
63
+ async def test_get_loop_evidence_returns_empty_for_missing_loop(self):
64
+ """get_loop_evidence should return empty list for non-existent loop."""
65
+ manager = WorkflowManager()
66
+ evidence = await manager.get_loop_evidence("nonexistent")
67
+ assert evidence == []
68
+
69
+ async def test_check_loop_completion_with_sufficient_evidence(self, mock_judge_handler):
70
+ """check_loop_completion should return True when judge says sufficient."""
71
+ manager = WorkflowManager()
72
+ await manager.add_loop("loop1", "Test query")
73
+
74
+ # Add evidence
75
+ evidence = [
76
+ Evidence(
77
+ content="Comprehensive evidence",
78
+ citation=Citation(
79
+ source="rag", # Use valid SourceName
80
+ title="Test",
81
+ url="https://example.com",
82
+ date="2024-01-01",
83
+ authors=[],
84
+ ),
85
+ relevance=0.9,
86
+ )
87
+ ]
88
+ await manager.add_loop_evidence("loop1", evidence)
89
+
90
+ # Mock judge to say sufficient
91
+ from src.utils.models import AssessmentDetails
92
+
93
+ mock_judge_handler.assess = AsyncMock(
94
+ return_value=JudgeAssessment(
95
+ details=AssessmentDetails(
96
+ mechanism_score=5,
97
+ mechanism_reasoning="Test mechanism reasoning that is long enough",
98
+ clinical_evidence_score=5,
99
+ clinical_reasoning="Test clinical reasoning that is long enough",
100
+ drug_candidates=[],
101
+ key_findings=[],
102
+ ),
103
+ sufficient=True,
104
+ confidence=0.95,
105
+ recommendation="synthesize",
106
+ reasoning="Evidence is sufficient to provide a comprehensive answer.",
107
+ )
108
+ )
109
+
110
+ should_complete, reason = await manager.check_loop_completion(
111
+ "loop1", "Test query", mock_judge_handler
112
+ )
113
+
114
+ assert should_complete is True
115
+ assert "sufficient" in reason.lower() or "judge" in reason.lower()
116
+ assert mock_judge_handler.assess.called
117
+
118
+ async def test_check_loop_completion_with_insufficient_evidence(self, mock_judge_handler):
119
+ """check_loop_completion should return False when judge says insufficient."""
120
+ manager = WorkflowManager()
121
+ await manager.add_loop("loop1", "Test query")
122
+
123
+ # Add minimal evidence
124
+ evidence = [
125
+ Evidence(
126
+ content="Minimal evidence",
127
+ citation=Citation(
128
+ source="rag", # Use valid SourceName
129
+ title="Test",
130
+ url="https://example.com",
131
+ date="2024-01-01",
132
+ authors=[],
133
+ ),
134
+ relevance=0.3,
135
+ )
136
+ ]
137
+ await manager.add_loop_evidence("loop1", evidence)
138
+
139
+ # Mock judge to say insufficient
140
+ from src.utils.models import AssessmentDetails
141
+
142
+ mock_judge_handler.assess = AsyncMock(
143
+ return_value=JudgeAssessment(
144
+ details=AssessmentDetails(
145
+ mechanism_score=3,
146
+ mechanism_reasoning="Test mechanism reasoning that is long enough",
147
+ clinical_evidence_score=3,
148
+ clinical_reasoning="Test clinical reasoning that is long enough",
149
+ drug_candidates=[],
150
+ key_findings=[],
151
+ ),
152
+ sufficient=False,
153
+ confidence=0.4,
154
+ recommendation="continue",
155
+ reasoning="Need more evidence to provide a comprehensive answer.",
156
+ )
157
+ )
158
+
159
+ should_complete, reason = await manager.check_loop_completion(
160
+ "loop1", "Test query", mock_judge_handler
161
+ )
162
+
163
+ assert should_complete is False
164
+ assert "judge" in reason.lower() or "evidence" in reason.lower()
165
+ assert mock_judge_handler.assess.called
166
+
167
+ async def test_check_loop_completion_with_no_evidence(self, mock_judge_handler):
168
+ """check_loop_completion should return False when no evidence exists."""
169
+ manager = WorkflowManager()
170
+ await manager.add_loop("loop1", "Test query")
171
+
172
+ # Don't add any evidence
173
+
174
+ should_complete, reason = await manager.check_loop_completion(
175
+ "loop1", "Test query", mock_judge_handler
176
+ )
177
+
178
+ assert should_complete is False
179
+ assert "no evidence" in reason.lower() or "not" in reason.lower()
180
+ # Judge should not be called if no evidence
181
+ assert not mock_judge_handler.assess.called
182
+
183
+ async def test_check_loop_completion_handles_judge_error(self, mock_judge_handler):
184
+ """check_loop_completion should handle judge errors gracefully."""
185
+ manager = WorkflowManager()
186
+ await manager.add_loop("loop1", "Test query")
187
+
188
+ evidence = [
189
+ Evidence(
190
+ content="Test evidence",
191
+ citation=Citation(
192
+ source="rag", # Use valid SourceName
193
+ title="Test",
194
+ url="https://example.com",
195
+ date="2024-01-01",
196
+ authors=[],
197
+ ),
198
+ relevance=0.8,
199
+ )
200
+ ]
201
+ await manager.add_loop_evidence("loop1", evidence)
202
+
203
+ # Mock judge to raise error
204
+ mock_judge_handler.assess = AsyncMock(side_effect=Exception("Judge error"))
205
+
206
+ should_complete, reason = await manager.check_loop_completion(
207
+ "loop1", "Test query", mock_judge_handler
208
+ )
209
+
210
+ assert should_complete is False
211
+ assert "error" in reason.lower() or "failed" in reason.lower()
212
+
213
+ async def test_parallel_loops_with_judge_early_termination(
214
+ self, mock_judge_handler, mock_iterative_flow
215
+ ):
216
+ """Parallel loops should terminate early when judge says sufficient."""
217
+ manager = WorkflowManager()
218
+
219
+ # Create multiple loops
220
+ loop_configs = [
221
+ {"loop_id": "loop1", "query": "Query 1"},
222
+ {"loop_id": "loop2", "query": "Query 2"},
223
+ ]
224
+
225
+ # Define loop function that extracts loop_func from config if needed
226
+ async def loop_func(config: dict) -> str:
227
+ return await mock_iterative_flow.run(config.get("query", ""))
228
+
229
+ # Add evidence to loop1 that will trigger early completion
230
+ await manager.add_loop("loop1", "Query 1")
231
+ evidence = [
232
+ Evidence(
233
+ content="Comprehensive evidence for query 1",
234
+ citation=Citation(
235
+ source="rag", # Use valid SourceName
236
+ title="Test",
237
+ url="https://example.com",
238
+ date="2024-01-01",
239
+ authors=[],
240
+ ),
241
+ relevance=0.95,
242
+ )
243
+ ]
244
+ await manager.add_loop_evidence("loop1", evidence)
245
+
246
+ # Mock judge to say sufficient for loop1
247
+ call_count = {"count": 0}
248
+
249
+ def mock_assess(query: str, evidence_list: list[Evidence]) -> JudgeAssessment:
250
+ from src.utils.models import AssessmentDetails
251
+
252
+ call_count["count"] += 1
253
+ if "Query 1" in query or len(evidence_list) > 0:
254
+ return JudgeAssessment(
255
+ details=AssessmentDetails(
256
+ mechanism_score=5,
257
+ mechanism_reasoning="Test mechanism reasoning that is long enough",
258
+ clinical_evidence_score=5,
259
+ clinical_reasoning="Test clinical reasoning that is long enough",
260
+ drug_candidates=[],
261
+ key_findings=[],
262
+ ),
263
+ sufficient=True,
264
+ confidence=0.95,
265
+ recommendation="synthesize",
266
+ reasoning="Sufficient evidence has been collected to answer the query.",
267
+ )
268
+ return JudgeAssessment(
269
+ details=AssessmentDetails(
270
+ mechanism_score=3,
271
+ mechanism_reasoning="Test mechanism reasoning that is long enough",
272
+ clinical_evidence_score=3,
273
+ clinical_reasoning="Test clinical reasoning that is long enough",
274
+ drug_candidates=[],
275
+ key_findings=[],
276
+ ),
277
+ sufficient=False,
278
+ confidence=0.5,
279
+ recommendation="continue",
280
+ reasoning="Need more evidence to provide a comprehensive answer.",
281
+ )
282
+
283
+ mock_judge_handler.assess = AsyncMock(side_effect=mock_assess)
284
+
285
+ # Run loops in parallel
286
+ with patch("src.middleware.workflow_manager.get_workflow_state") as mock_state:
287
+ mock_state_obj = MagicMock()
288
+ mock_state_obj.evidence = []
289
+ mock_state.return_value = mock_state_obj
290
+
291
+ results = await manager.run_loops_parallel(
292
+ loop_configs, loop_func=loop_func, judge_handler=mock_judge_handler
293
+ )
294
+
295
+ # Both loops should complete
296
+ assert len(results) == 2
297
+ assert all(isinstance(r, str) for r in results)
298
+
299
+ async def test_parallel_loops_aggregate_evidence(self, mock_judge_handler):
300
+ """Parallel loops should aggregate evidence from all loops."""
301
+ manager = WorkflowManager()
302
+
303
+ # Create loops
304
+ await manager.add_loop("loop1", "Query 1")
305
+ await manager.add_loop("loop2", "Query 2")
306
+
307
+ # Add evidence to each loop
308
+ evidence1 = [
309
+ Evidence(
310
+ content="Evidence from loop 1",
311
+ citation=Citation(
312
+ source="rag", # Use valid SourceName
313
+ title="Test 1",
314
+ url="https://example.com/1",
315
+ date="2024-01-01",
316
+ authors=[],
317
+ ),
318
+ relevance=0.8,
319
+ )
320
+ ]
321
+ evidence2 = [
322
+ Evidence(
323
+ content="Evidence from loop 2",
324
+ citation=Citation(
325
+ source="rag", # Use valid SourceName
326
+ title="Test 2",
327
+ url="https://example.com/2",
328
+ date="2024-01-01",
329
+ authors=[],
330
+ ),
331
+ relevance=0.9,
332
+ )
333
+ ]
334
+
335
+ await manager.add_loop_evidence("loop1", evidence1)
336
+ await manager.add_loop_evidence("loop2", evidence2)
337
+
338
+ # Get evidence from both loops
339
+ evidence1_retrieved = await manager.get_loop_evidence("loop1")
340
+ evidence2_retrieved = await manager.get_loop_evidence("loop2")
341
+
342
+ assert len(evidence1_retrieved) == 1
343
+ assert len(evidence2_retrieved) == 1
344
+ assert evidence1_retrieved[0].content == "Evidence from loop 1"
345
+ assert evidence2_retrieved[0].content == "Evidence from loop 2"
346
+
347
+ async def test_loop_status_updated_on_completion(self, mock_judge_handler):
348
+ """Loop status should be updated when judge determines completion."""
349
+ manager = WorkflowManager()
350
+ await manager.add_loop("loop1", "Test query")
351
+
352
+ # Add sufficient evidence
353
+ evidence = [
354
+ Evidence(
355
+ content="Sufficient evidence",
356
+ citation=Citation(
357
+ source="rag", # Use valid SourceName
358
+ title="Test",
359
+ url="https://example.com",
360
+ date="2024-01-01",
361
+ authors=[],
362
+ ),
363
+ relevance=0.95,
364
+ )
365
+ ]
366
+ await manager.add_loop_evidence("loop1", evidence)
367
+
368
+ from src.utils.models import AssessmentDetails
369
+
370
+ mock_judge_handler.assess = AsyncMock(
371
+ return_value=JudgeAssessment(
372
+ details=AssessmentDetails(
373
+ mechanism_score=5,
374
+ mechanism_reasoning="Test mechanism reasoning that is long enough",
375
+ clinical_evidence_score=5,
376
+ clinical_reasoning="Test clinical reasoning that is long enough",
377
+ drug_candidates=[],
378
+ key_findings=[],
379
+ ),
380
+ sufficient=True,
381
+ confidence=0.95,
382
+ recommendation="synthesize",
383
+ reasoning="Complete evidence has been collected to answer the query.",
384
+ )
385
+ )
386
+
387
+ # Check completion (this should update status internally if implemented)
388
+ should_complete, _ = await manager.check_loop_completion(
389
+ "loop1", "Test query", mock_judge_handler
390
+ )
391
+
392
+ assert should_complete is True
393
+ # Status update would happen in run_loops_parallel, not in check_loop_completion
394
+ loop = await manager.get_loop("loop1")
395
+ assert loop is not None
396
+ # Status might still be "pending" or "running" until run_loops_parallel updates it
tests/integration/test_rag_integration.py ADDED
@@ -0,0 +1,343 @@
1
+ """Integration tests for RAG integration.
2
+
3
+ These tests require OPENAI_API_KEY and may make real API calls.
4
+ Marked with @pytest.mark.integration to skip in unit test runs.
5
+ """
6
+
7
+ import pytest
8
+
9
+ from src.services.llamaindex_rag import get_rag_service
10
+ from src.tools.rag_tool import create_rag_tool
11
+ from src.tools.search_handler import SearchHandler
12
+ from src.tools.tool_executor import execute_agent_task
13
+ from src.utils.config import settings
14
+ from src.utils.models import AgentTask, Citation, Evidence
15
+
16
+
17
+ @pytest.mark.integration
18
+ class TestRAGServiceIntegration:
19
+ """Integration tests for LlamaIndexRAGService."""
20
+
21
+ @pytest.mark.asyncio
22
+ async def test_rag_service_ingest_and_retrieve(self):
23
+ """RAG service should ingest and retrieve evidence."""
24
+ if not settings.openai_api_key:
25
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
26
+
27
+ # Create RAG service
28
+ rag_service = get_rag_service(collection_name="test_integration")
29
+
30
+ # Create sample evidence
31
+ evidence_list = [
32
+ Evidence(
33
+ content="Metformin is a first-line treatment for type 2 diabetes. It works by reducing glucose production in the liver and improving insulin sensitivity.",
34
+ citation=Citation(
35
+ source="pubmed",
36
+ title="Metformin Mechanism of Action",
37
+ url="https://pubmed.ncbi.nlm.nih.gov/12345678/",
38
+ date="2024-01-15",
39
+ authors=["Smith J", "Johnson M"],
40
+ ),
41
+ relevance=0.9,
42
+ ),
43
+ Evidence(
44
+ content="Recent studies suggest metformin may have neuroprotective effects in Alzheimer's disease models.",
45
+ citation=Citation(
46
+ source="pubmed",
47
+ title="Metformin and Neuroprotection",
48
+ url="https://pubmed.ncbi.nlm.nih.gov/12345679/",
49
+ date="2024-02-20",
50
+ authors=["Brown K", "Davis L"],
51
+ ),
52
+ relevance=0.85,
53
+ ),
54
+ ]
55
+
56
+ # Ingest evidence
57
+ rag_service.ingest_evidence(evidence_list)
58
+
59
+ # Retrieve evidence
60
+ results = rag_service.retrieve("metformin diabetes", top_k=2)
61
+
62
+ # Assert
63
+ assert len(results) > 0
64
+ assert any("metformin" in r["text"].lower() for r in results)
65
+ assert all("text" in r for r in results)
66
+ assert all("metadata" in r for r in results)
67
+
68
+ # Cleanup
69
+ rag_service.clear_collection()
70
+
71
+ @pytest.mark.asyncio
72
+ async def test_rag_service_query(self):
73
+ """RAG service should synthesize responses from ingested evidence."""
74
+ if not settings.openai_api_key:
75
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
76
+
77
+ rag_service = get_rag_service(collection_name="test_query")
78
+
79
+ # Ingest evidence
80
+ evidence_list = [
81
+ Evidence(
82
+ content="Python is a high-level programming language known for its simplicity and readability.",
83
+ citation=Citation(
84
+ source="pubmed",
85
+ title="Python Programming",
86
+ url="https://example.com/python",
87
+ date="2024",
88
+ authors=["Author"],
89
+ ),
90
+ )
91
+ ]
92
+ rag_service.ingest_evidence(evidence_list)
93
+
94
+ # Query
95
+ response = rag_service.query("What is Python?", top_k=1)
96
+
97
+ assert isinstance(response, str)
98
+ assert len(response) > 0
99
+ assert "python" in response.lower()
100
+
101
+ # Cleanup
102
+ rag_service.clear_collection()
103
+
104
+
105
+ @pytest.mark.integration
106
+ class TestRAGToolIntegration:
107
+ """Integration tests for RAGTool."""
108
+
109
+ @pytest.mark.asyncio
110
+ async def test_rag_tool_search(self):
111
+ """RAGTool should search RAG service and return Evidence objects."""
112
+ if not settings.openai_api_key:
113
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
114
+
115
+ # Create RAG service and ingest evidence
116
+ rag_service = get_rag_service(collection_name="test_rag_tool")
117
+ evidence_list = [
118
+ Evidence(
119
+ content="Machine learning is a subset of artificial intelligence.",
120
+ citation=Citation(
121
+ source="pubmed",
122
+ title="ML Basics",
123
+ url="https://example.com/ml",
124
+ date="2024",
125
+ authors=["ML Expert"],
126
+ ),
127
+ )
128
+ ]
129
+ rag_service.ingest_evidence(evidence_list)
130
+
131
+ # Create RAG tool
132
+ tool = create_rag_tool(rag_service=rag_service)
133
+
134
+ # Search
135
+ results = await tool.search("machine learning", max_results=5)
136
+
137
+ # Assert
138
+ assert len(results) > 0
139
+ assert all(isinstance(e, Evidence) for e in results)
140
+ assert results[0].citation.source == "rag"
141
+ assert (
142
+ "machine learning" in results[0].content.lower()
143
+ or "artificial intelligence" in results[0].content.lower()
144
+ )
145
+
146
+ # Cleanup
147
+ rag_service.clear_collection()
148
+
149
+ @pytest.mark.asyncio
150
+ async def test_rag_tool_empty_collection(self):
151
+ """RAGTool should return empty list when collection is empty."""
152
+ if not settings.openai_api_key:
153
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
154
+
155
+ rag_service = get_rag_service(collection_name="test_empty")
156
+ rag_service.clear_collection() # Ensure empty
157
+
158
+ tool = create_rag_tool(rag_service=rag_service)
159
+ results = await tool.search("any query")
160
+
161
+ assert results == []
162
+
163
+
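Driving `RAGTool` directly looks like the sketch below (same import assumptions as above); `search()` returns `Evidence` objects whose `citation.source` is `"rag"`, or an empty list for an empty collection.

```python
import asyncio

from src.tools.rag_tool import create_rag_tool, get_rag_service  # assumed module path


async def main() -> None:
    rag = get_rag_service(collection_name="scratch")  # pre-populated as in the sketch above
    tool = create_rag_tool(rag_service=rag)
    evidence = await tool.search("machine learning", max_results=5)
    for e in evidence:  # empty list when nothing has been ingested
        print(e.citation.source, e.content[:60])


asyncio.run(main())
```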
164
+ @pytest.mark.integration
165
+ class TestRAGAgentIntegration:
166
+ """Integration tests for RAGAgent in tool executor."""
167
+
168
+ @pytest.mark.asyncio
169
+ async def test_rag_agent_execution(self):
170
+ """RAGAgent should execute and return ToolAgentOutput."""
171
+ if not settings.openai_api_key:
172
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
173
+
174
+ # Setup: Ingest evidence into RAG
175
+ rag_service = get_rag_service(collection_name="test_rag_agent")
176
+ evidence_list = [
177
+ Evidence(
178
+ content="Deep learning uses neural networks with multiple layers.",
179
+ citation=Citation(
180
+ source="pubmed",
181
+ title="Deep Learning",
182
+ url="https://example.com/dl",
183
+ date="2024",
184
+ authors=["DL Researcher"],
185
+ ),
186
+ )
187
+ ]
188
+ rag_service.ingest_evidence(evidence_list)
189
+
190
+ # Execute RAGAgent task
191
+ task = AgentTask(
192
+ agent="RAGAgent",
193
+ query="deep learning",
194
+ gap="Need information about deep learning",
195
+ )
196
+
197
+ result = await execute_agent_task(task)
198
+
199
+ # Assert
200
+ assert result.output
201
+ assert "deep learning" in result.output.lower() or "neural network" in result.output.lower()
202
+ assert len(result.sources) > 0
203
+
204
+ # Cleanup
205
+ rag_service.clear_collection()
206
+
207
+
208
+ @pytest.mark.integration
209
+ class TestRAGSearchHandlerIntegration:
210
+ """Integration tests for RAG in SearchHandler."""
211
+
212
+ @pytest.mark.asyncio
213
+ async def test_search_handler_with_rag(self):
214
+ """SearchHandler should work with RAG tool included."""
215
+ if not settings.openai_api_key:
216
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
217
+
218
+ # Setup: Create RAG service and ingest some evidence
219
+ rag_service = get_rag_service(collection_name="test_search_handler")
220
+ evidence_list = [
221
+ Evidence(
222
+ content="Test evidence for search handler integration.",
223
+ citation=Citation(
224
+ source="pubmed",
225
+ title="Test Evidence",
226
+ url="https://example.com/test",
227
+ date="2024",
228
+ authors=["Tester"],
229
+ ),
230
+ )
231
+ ]
232
+ rag_service.ingest_evidence(evidence_list)
233
+
234
+ # Create SearchHandler with RAG
235
+ handler = SearchHandler(
236
+ tools=[], # No other tools
237
+ include_rag=True,
238
+ auto_ingest_to_rag=False, # Don't auto-ingest (already has data)
239
+ )
240
+
241
+ # Execute search
242
+ result = await handler.execute("test evidence", max_results_per_tool=5)
243
+
244
+ # Assert
245
+ assert result.total_found > 0
246
+ assert "rag" in result.sources_searched
247
+ assert any(e.citation.source == "rag" for e in result.evidence)
248
+
249
+ # Cleanup
250
+ rag_service.clear_collection()
251
+
252
+ @pytest.mark.asyncio
253
+ async def test_search_handler_auto_ingest(self):
254
+ """SearchHandler should auto-ingest evidence into RAG."""
255
+ if not settings.openai_api_key:
256
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
257
+
258
+ # Create empty RAG service
259
+ rag_service = get_rag_service(collection_name="test_auto_ingest")
260
+ rag_service.clear_collection()
261
+
262
+ # Create mock tool that returns evidence
263
+ from unittest.mock import AsyncMock
264
+
265
+ mock_tool = AsyncMock()
266
+ mock_tool.name = "pubmed"
267
+ mock_tool.search = AsyncMock(
268
+ return_value=[
269
+ Evidence(
270
+ content="Evidence to be ingested",
271
+ citation=Citation(
272
+ source="pubmed",
273
+ title="Test",
274
+ url="https://example.com",
275
+ date="2024",
276
+ authors=[],
277
+ ),
278
+ )
279
+ ]
280
+ )
281
+
282
+ # Create handler with auto-ingest enabled
283
+ handler = SearchHandler(
284
+ tools=[mock_tool],
285
+ include_rag=False, # Don't include RAG as search tool
286
+ auto_ingest_to_rag=True,
287
+ )
288
+ handler._rag_service = rag_service # Inject RAG service
289
+
290
+ # Execute search
291
+ await handler.execute("test query")
292
+
293
+ # Verify evidence was ingested
294
+ rag_results = rag_service.retrieve("Evidence to be ingested", top_k=1)
295
+ assert len(rag_results) > 0
296
+
297
+ # Cleanup
298
+ rag_service.clear_collection()
299
+
300
+
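The handler wiring under test, sketched end to end: `include_rag` exposes the collection as one more search tool, while `auto_ingest_to_rag` writes fresh results back into it. The `SearchHandler` import path is an assumption from this commit's file list.

```python
import asyncio

from src.tools.pubmed import PubMedTool
from src.tools.search_handler import SearchHandler  # assumed path: src/tools/search_handler.py


async def main() -> None:
    handler = SearchHandler(
        tools=[PubMedTool()],
        include_rag=True,  # the RAG collection participates as a search tool
        auto_ingest_to_rag=True,  # new evidence is persisted back into the collection
    )
    result = await handler.execute("metformin long covid", max_results_per_tool=5)
    print(result.total_found, result.sources_searched)


asyncio.run(main())
```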
301
+ @pytest.mark.integration
302
+ class TestRAGHybridSearchIntegration:
303
+ """Integration tests for hybrid search (RAG + database)."""
304
+
305
+ @pytest.mark.asyncio
306
+ async def test_hybrid_search_rag_and_pubmed(self):
307
+ """SearchHandler should support RAG + PubMed hybrid search."""
308
+ if not settings.openai_api_key:
309
+ pytest.skip("OPENAI_API_KEY required for RAG integration tests")
310
+
311
+ # Setup: Ingest evidence into RAG
312
+ rag_service = get_rag_service(collection_name="test_hybrid")
313
+ evidence_list = [
314
+ Evidence(
315
+ content="Previously collected evidence about metformin.",
316
+ citation=Citation(
317
+ source="pubmed",
318
+ title="Previous Research",
319
+ url="https://example.com/prev",
320
+ date="2024",
321
+ authors=[],
322
+ ),
323
+ )
324
+ ]
325
+ rag_service.ingest_evidence(evidence_list)
326
+
327
+ # Note: This test would require real PubMed API access
328
+ # For now, we'll just test that the handler can be created with both tools
329
+ from src.tools.pubmed import PubMedTool
330
+
331
+ handler = SearchHandler(
332
+ tools=[PubMedTool()],
333
+ include_rag=True,
334
+ auto_ingest_to_rag=True,
335
+ )
336
+
337
+ # Verify handler has both tools
338
+ tool_names = [t.name for t in handler.tools]
339
+ assert "pubmed" in tool_names
340
+ assert "rag" in tool_names
341
+
342
+ # Cleanup
343
+ rag_service.clear_collection()
tests/integration/test_research_flows.py ADDED
@@ -0,0 +1,584 @@
1
+ """Integration tests for research flows.
2
+
3
+ These tests require API keys and may make real API calls.
4
+ Marked with @pytest.mark.integration to skip in unit test runs.
5
+ """
6
+
7
+ import pytest
8
+
9
+ from src.agent_factory.agents import (
10
+ create_deep_flow,
11
+ create_iterative_flow,
12
+ create_planner_agent,
13
+ )
14
+ from src.orchestrator.graph_orchestrator import create_graph_orchestrator
15
+ from src.utils.config import settings
16
+
17
+
18
+ @pytest.mark.integration
19
+ class TestPlannerAgentIntegration:
20
+ """Integration tests for PlannerAgent with real API calls."""
21
+
22
+ @pytest.mark.asyncio
23
+ async def test_planner_agent_creates_plan(self):
24
+ """PlannerAgent should create a valid report plan with real API."""
25
+ if not settings.has_openai_key and not settings.has_anthropic_key:
26
+ pytest.skip("No OpenAI or Anthropic API key available")
27
+
28
+ planner = create_planner_agent()
29
+ result = await planner.run("What are the main features of Python programming language?")
30
+
31
+ assert result.report_title
32
+ assert len(result.report_outline) > 0
33
+ assert result.report_outline[0].title
34
+ assert result.report_outline[0].key_question
35
+
36
+ @pytest.mark.asyncio
37
+ async def test_planner_agent_includes_background_context(self):
38
+ """PlannerAgent should include background context in plan."""
39
+ if not settings.has_openai_key and not settings.has_anthropic_key:
40
+ pytest.skip("No OpenAI or Anthropic API key available")
41
+
42
+ planner = create_planner_agent()
43
+ result = await planner.run("Explain quantum computing basics")
44
+
45
+ assert result.background_context
46
+ assert len(result.background_context) > 50 # Should have substantial context
47
+
48
+
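As a reference point for these assertions, a minimal planner invocation (attribute names taken from the assertions above; running it needs one of the API keys checked in the skips):

```python
import asyncio

from src.agent_factory.agents import create_planner_agent


async def main() -> None:
    planner = create_planner_agent()
    plan = await planner.run("Explain quantum computing basics")
    print(plan.report_title)
    print(plan.background_context[:100])
    for section in plan.report_outline:
        print("-", section.title, "::", section.key_question)


asyncio.run(main())
```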
49
+ @pytest.mark.integration
50
+ class TestIterativeResearchFlowIntegration:
51
+ """Integration tests for IterativeResearchFlow with real API calls."""
52
+
53
+ @pytest.mark.asyncio
54
+ async def test_iterative_flow_completes_simple_query(self):
55
+ """IterativeResearchFlow should complete a simple research query."""
56
+ if not settings.has_openai_key and not settings.has_anthropic_key:
57
+ pytest.skip("No OpenAI or Anthropic API key available")
58
+
59
+ flow = create_iterative_flow(max_iterations=2, max_time_minutes=2)
60
+ result = await flow.run(
61
+ query="What is the capital of France?",
62
+ output_length="A short paragraph",
63
+ )
64
+
65
+ assert isinstance(result, str)
66
+ assert len(result) > 0
67
+ # Should mention Paris
68
+ assert "paris" in result.lower() or "france" in result.lower()
69
+
70
+ @pytest.mark.asyncio
71
+ async def test_iterative_flow_respects_max_iterations(self):
72
+ """IterativeResearchFlow should respect max_iterations limit."""
73
+ if not settings.has_openai_key and not settings.has_anthropic_key:
74
+ pytest.skip("No OpenAI or Anthropic API key available")
75
+
76
+ flow = create_iterative_flow(max_iterations=1, max_time_minutes=5)
77
+ result = await flow.run(query="What are the main features of Python?")
78
+
79
+ assert isinstance(result, str)
80
+ # Should complete within 1 iteration (or hit max)
81
+ assert flow.iteration <= 1
82
+
83
+ @pytest.mark.asyncio
84
+ async def test_iterative_flow_with_background_context(self):
85
+ """IterativeResearchFlow should use background context."""
86
+ if not settings.has_openai_key and not settings.has_anthropic_key:
87
+ pytest.skip("No OpenAI or Anthropic API key available")
88
+
89
+ flow = create_iterative_flow(max_iterations=2, max_time_minutes=2)
90
+ result = await flow.run(
91
+ query="What is machine learning?",
92
+ background_context="Machine learning is a subset of artificial intelligence.",
93
+ )
94
+
95
+ assert isinstance(result, str)
96
+ assert len(result) > 0
97
+
98
+
99
+ @pytest.mark.integration
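The flow API these tests cover, condensed into one sketch; every keyword argument below appears in the tests above.

```python
import asyncio

from src.agent_factory.agents import create_iterative_flow


async def main() -> None:
    flow = create_iterative_flow(max_iterations=2, max_time_minutes=2)
    report = await flow.run(
        query="What is machine learning?",
        background_context="Optional prior context seeded into the loop.",
        output_length="A short paragraph",
    )
    print(report)  # markdown report string
    print(flow.iteration)  # bounded by max_iterations


asyncio.run(main())
```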
100
+ class TestDeepResearchFlowIntegration:
101
+ """Integration tests for DeepResearchFlow with real API calls."""
102
+
103
+ @pytest.mark.asyncio
104
+ async def test_deep_flow_creates_multi_section_report(self):
105
+ """DeepResearchFlow should create a report with multiple sections."""
106
+ if not settings.has_openai_key and not settings.has_anthropic_key:
107
+ pytest.skip("No OpenAI or Anthropic API key available")
108
+
109
+ flow = create_deep_flow(
110
+ max_iterations=1, # Keep it short for testing
111
+ max_time_minutes=3,
112
+ )
113
+ result = await flow.run("What are the main features of Python programming language?")
114
+
115
+ assert isinstance(result, str)
116
+ assert len(result) > 100 # Should have substantial content
117
+ # Should have section structure
118
+ assert "#" in result  # any "##" header also contains "#", so a single check suffices
119
+
120
+ @pytest.mark.asyncio
121
+ async def test_deep_flow_uses_long_writer(self):
122
+ """DeepResearchFlow should use long writer by default."""
123
+ if not settings.has_openai_key and not settings.has_anthropic_key:
124
+ pytest.skip("No OpenAI or Anthropic API key available")
125
+
126
+ flow = create_deep_flow(
127
+ max_iterations=1,
128
+ max_time_minutes=3,
129
+ use_long_writer=True,
130
+ )
131
+ result = await flow.run("Explain the basics of quantum computing")
132
+
133
+ assert isinstance(result, str)
134
+ assert len(result) > 0
135
+
136
+ @pytest.mark.asyncio
137
+ async def test_deep_flow_uses_proofreader_when_specified(self):
138
+ """DeepResearchFlow should use proofreader when use_long_writer=False."""
139
+ if not settings.has_openai_key and not settings.has_anthropic_key:
140
+ pytest.skip("No OpenAI or Anthropic API key available")
141
+
142
+ flow = create_deep_flow(
143
+ max_iterations=1,
144
+ max_time_minutes=3,
145
+ use_long_writer=False,
146
+ )
147
+ result = await flow.run("What is artificial intelligence?")
148
+
149
+ assert isinstance(result, str)
150
+ assert len(result) > 0
151
+
152
+
153
+ @pytest.mark.integration
154
+ class TestGraphOrchestratorIntegration:
155
+ """Integration tests for GraphOrchestrator with real API calls."""
156
+
157
+ @pytest.mark.asyncio
158
+ async def test_graph_orchestrator_iterative_mode(self):
159
+ """GraphOrchestrator should run in iterative mode."""
160
+ if not settings.has_openai_key and not settings.has_anthropic_key:
161
+ pytest.skip("No OpenAI or Anthropic API key available")
162
+
163
+ orchestrator = create_graph_orchestrator(
164
+ mode="iterative",
165
+ max_iterations=1,
166
+ max_time_minutes=2,
167
+ )
168
+
169
+ events = []
170
+ async for event in orchestrator.run("What is Python?"):
171
+ events.append(event)
172
+
173
+ assert len(events) > 0
174
+ event_types = [e.type for e in events]
175
+ assert "started" in event_types
176
+ assert "complete" in event_types
177
+
178
+ @pytest.mark.asyncio
179
+ async def test_graph_orchestrator_deep_mode(self):
180
+ """GraphOrchestrator should run in deep mode."""
181
+ if not settings.has_openai_key and not settings.has_anthropic_key:
182
+ pytest.skip("No OpenAI or Anthropic API key available")
183
+
184
+ orchestrator = create_graph_orchestrator(
185
+ mode="deep",
186
+ max_iterations=1,
187
+ max_time_minutes=3,
188
+ )
189
+
190
+ events = []
191
+ async for event in orchestrator.run("What are the main features of Python?"):
192
+ events.append(event)
193
+
194
+ assert len(events) > 0
195
+ event_types = [e.type for e in events]
196
+ assert "started" in event_types
197
+ assert "complete" in event_types
198
+
199
+ @pytest.mark.asyncio
200
+ async def test_graph_orchestrator_auto_mode(self):
201
+ """GraphOrchestrator should auto-detect research mode."""
202
+ if not settings.has_openai_key and not settings.has_anthropic_key:
203
+ pytest.skip("No OpenAI or Anthropic API key available")
204
+
205
+ orchestrator = create_graph_orchestrator(
206
+ mode="auto",
207
+ max_iterations=1,
208
+ max_time_minutes=2,
209
+ )
210
+
211
+ events = []
212
+ async for event in orchestrator.run("What is Python?"):
213
+ events.append(event)
214
+
215
+ assert len(events) > 0
216
+ # Should complete successfully regardless of mode
217
+ event_types = [e.type for e in events]
218
+ assert "complete" in event_types
219
+
220
+
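The event-streaming contract asserted here, as a sketch: `run()` is an async generator of events carrying `type` and `message`, and the final report arrives on the `complete` event.

```python
import asyncio

from src.orchestrator.graph_orchestrator import create_graph_orchestrator


async def main() -> None:
    orchestrator = create_graph_orchestrator(
        mode="auto",  # "iterative", "deep", or auto-detected from the query
        max_iterations=1,
        max_time_minutes=2,
    )
    final_report = ""
    async for event in orchestrator.run("What is Python?"):
        if event.type == "complete":
            final_report = event.message
    print(final_report)


asyncio.run(main())
```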
221
+ @pytest.mark.integration
222
+ class TestGraphOrchestrationIntegration:
223
+ """Integration tests for graph-based orchestration with real API calls."""
224
+
225
+ @pytest.mark.asyncio
226
+ async def test_iterative_flow_with_graph_execution(self):
227
+ """IterativeResearchFlow should work with graph execution enabled."""
228
+ if not settings.has_openai_key and not settings.has_anthropic_key:
229
+ pytest.skip("No OpenAI or Anthropic API key available")
230
+
231
+ flow = create_iterative_flow(
232
+ max_iterations=1,
233
+ max_time_minutes=2,
234
+ use_graph=True,
235
+ )
236
+ result = await flow.run(query="What is the capital of France?")
237
+
238
+ assert isinstance(result, str)
239
+ assert len(result) > 0
240
+ # Should mention Paris
241
+ assert "paris" in result.lower() or "france" in result.lower()
242
+
243
+ @pytest.mark.asyncio
244
+ async def test_deep_flow_with_graph_execution(self):
245
+ """DeepResearchFlow should work with graph execution enabled."""
246
+ if not settings.has_openai_key and not settings.has_anthropic_key:
247
+ pytest.skip("No OpenAI or Anthropic API key available")
248
+
249
+ flow = create_deep_flow(
250
+ max_iterations=1,
251
+ max_time_minutes=3,
252
+ use_graph=True,
253
+ )
254
+ result = await flow.run("What are the main features of Python programming language?")
255
+
256
+ assert isinstance(result, str)
257
+ assert len(result) > 100 # Should have substantial content
258
+
259
+ @pytest.mark.asyncio
260
+ async def test_graph_orchestrator_with_graph_execution(self):
261
+ """GraphOrchestrator should work with graph execution enabled."""
262
+ if not settings.has_openai_key and not settings.has_anthropic_key:
263
+ pytest.skip("No OpenAI or Anthropic API key available")
264
+
265
+ orchestrator = create_graph_orchestrator(
266
+ mode="iterative",
267
+ max_iterations=1,
268
+ max_time_minutes=2,
269
+ use_graph=True,
270
+ )
271
+
272
+ events = []
273
+ async for event in orchestrator.run("What is Python?"):
274
+ events.append(event)
275
+
276
+ assert len(events) > 0
277
+ event_types = [e.type for e in events]
278
+ assert "started" in event_types
279
+ assert "complete" in event_types
280
+
281
+ # Extract final report from complete event
282
+ complete_events = [e for e in events if e.type == "complete"]
283
+ assert len(complete_events) > 0
284
+ final_report = complete_events[0].message
285
+ assert isinstance(final_report, str)
286
+ assert len(final_report) > 0
287
+
288
+ @pytest.mark.asyncio
289
+ async def test_graph_orchestrator_parallel_execution(self):
290
+ """GraphOrchestrator should support parallel execution in deep mode."""
291
+ if not settings.has_openai_key and not settings.has_anthropic_key:
292
+ pytest.skip("No OpenAI or Anthropic API key available")
293
+
294
+ orchestrator = create_graph_orchestrator(
295
+ mode="deep",
296
+ max_iterations=1,
297
+ max_time_minutes=3,
298
+ use_graph=True,
299
+ )
300
+
301
+ events = []
302
+ async for event in orchestrator.run("What are the main features of Python?"):
303
+ events.append(event)
304
+
305
+ assert len(events) > 0
306
+ event_types = [e.type for e in events]
307
+ assert "started" in event_types
308
+ assert "complete" in event_types
309
+
310
+ @pytest.mark.asyncio
311
+ async def test_graph_vs_chain_execution_comparison(self):
312
+ """Both graph and chain execution should produce similar results."""
313
+ if not settings.has_openai_key and not settings.has_anthropic_key:
314
+ pytest.skip("No OpenAI or Anthropic API key available")
315
+
316
+ query = "What is the capital of France?"
317
+
318
+ # Run with graph execution
319
+ flow_graph = create_iterative_flow(
320
+ max_iterations=1,
321
+ max_time_minutes=2,
322
+ use_graph=True,
323
+ )
324
+ result_graph = await flow_graph.run(query=query)
325
+
326
+ # Run with agent chains
327
+ flow_chains = create_iterative_flow(
328
+ max_iterations=1,
329
+ max_time_minutes=2,
330
+ use_graph=False,
331
+ )
332
+ result_chains = await flow_chains.run(query=query)
333
+
334
+ # Both should produce valid results
335
+ assert isinstance(result_graph, str)
336
+ assert isinstance(result_chains, str)
337
+ assert len(result_graph) > 0
338
+ assert len(result_chains) > 0
339
+
340
+ # Both should mention the answer (Paris)
341
+ assert "paris" in result_graph.lower() or "france" in result_graph.lower()
342
+ assert "paris" in result_chains.lower() or "france" in result_chains.lower()
343
+
344
+
345
+ @pytest.mark.integration
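The equivalence being tested: the same flow runs either over the node graph or as chained agent calls, toggled by `use_graph`.

```python
import asyncio

from src.agent_factory.agents import create_iterative_flow


async def main() -> None:
    for use_graph in (True, False):
        flow = create_iterative_flow(max_iterations=1, max_time_minutes=2, use_graph=use_graph)
        report = await flow.run(query="What is the capital of France?")
        print(f"use_graph={use_graph}: {len(report)} chars")  # both paths should answer


asyncio.run(main())
```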
346
+ class TestReportSynthesisIntegration:
347
+ """Integration tests for report synthesis with writer agents."""
348
+
349
+ @pytest.mark.asyncio
350
+ async def test_iterative_flow_generates_report(self):
351
+ """IterativeResearchFlow should generate a report with writer agent."""
352
+ if not settings.has_openai_key and not settings.has_anthropic_key:
353
+ pytest.skip("No OpenAI or Anthropic API key available")
354
+
355
+ flow = create_iterative_flow(max_iterations=1, max_time_minutes=2)
356
+ result = await flow.run(
357
+ query="What is the capital of France?",
358
+ output_length="A short paragraph",
359
+ )
360
+
361
+ assert isinstance(result, str)
362
+ assert len(result) > 0
363
+ # Should be a formatted report
364
+ assert "paris" in result.lower() or "france" in result.lower()
365
+ # Should have some structure (markdown headers or content)
366
+ assert len(result) > 50
367
+
368
+ @pytest.mark.asyncio
369
+ async def test_iterative_flow_includes_citations(self):
370
+ """IterativeResearchFlow should include citations in the report."""
371
+ if not settings.has_openai_key and not settings.has_anthropic_key:
372
+ pytest.skip("No OpenAI or Anthropic API key available")
373
+
374
+ flow = create_iterative_flow(max_iterations=1, max_time_minutes=2)
375
+ result = await flow.run(
376
+ query="What is machine learning?",
377
+ output_length="A short paragraph",
378
+ )
379
+
380
+ assert isinstance(result, str)
381
+ # Should have some form of citations or references
382
+ # (either [1], [2] format or References section)
383
+ # Note: Citations may not always be present depending on findings
384
+ # This is a soft check - just verify report was generated
385
+ assert len(result) > 0
386
+
387
+ @pytest.mark.asyncio
388
+ async def test_iterative_flow_handles_empty_findings(self):
389
+ """IterativeResearchFlow should handle empty findings gracefully."""
390
+ if not settings.has_openai_key and not settings.has_anthropic_key:
391
+ pytest.skip("No OpenAI or Anthropic API key available")
392
+
393
+ flow = create_iterative_flow(max_iterations=1, max_time_minutes=1)
394
+ # Use a query that might not return findings quickly
395
+ result = await flow.run(
396
+ query="Test query with no findings",
397
+ output_length="A short paragraph",
398
+ )
399
+
400
+ # Should still return a report (even if minimal)
401
+ assert isinstance(result, str)
402
+ # Writer agent should handle empty findings with fallback
403
+
404
+ @pytest.mark.asyncio
405
+ async def test_deep_flow_with_long_writer(self):
406
+ """DeepResearchFlow should use long writer to create sections."""
407
+ if not settings.has_openai_key and not settings.has_anthropic_key:
408
+ pytest.skip("No OpenAI or Anthropic API key available")
409
+
410
+ flow = create_deep_flow(
411
+ max_iterations=1,
412
+ max_time_minutes=3,
413
+ use_long_writer=True,
414
+ )
415
+ result = await flow.run("What are the main features of Python programming language?")
416
+
417
+ assert isinstance(result, str)
418
+ assert len(result) > 100 # Should have substantial content
419
+ # Should have section structure (table of contents or sections)
420
+ has_structure = (
421
+ "##" in result
422
+ or "#" in result
423
+ or "table of contents" in result.lower()
424
+ or "introduction" in result.lower()
425
+ )
426
+ # Long writer should create structured report
427
+ assert has_structure or len(result) > 200
428
+
429
+ @pytest.mark.asyncio
430
+ async def test_deep_flow_creates_sections(self):
431
+ """DeepResearchFlow should create multiple sections in the report."""
432
+ if not settings.has_openai_key and not settings.has_anthropic_key:
433
+ pytest.skip("No OpenAI or Anthropic API key available")
434
+
435
+ flow = create_deep_flow(
436
+ max_iterations=1,
437
+ max_time_minutes=3,
438
+ use_long_writer=True,
439
+ )
440
+ result = await flow.run("Explain the basics of quantum computing")
441
+
442
+ assert isinstance(result, str)
443
+ # Should have multiple sections (indicated by headers)
444
+ # Should have at least some structure
445
+ assert len(result) > 100
446
+
447
+ @pytest.mark.asyncio
448
+ async def test_deep_flow_aggregates_references(self):
449
+ """DeepResearchFlow should aggregate references from all sections."""
450
+ if not settings.has_openai_key and not settings.has_anthropic_key:
451
+ pytest.skip("No OpenAI or Anthropic API key available")
452
+
453
+ flow = create_deep_flow(
454
+ max_iterations=1,
455
+ max_time_minutes=3,
456
+ use_long_writer=True,
457
+ )
458
+ result = await flow.run("What are the main features of Python programming language?")
459
+
460
+ assert isinstance(result, str)
461
+ # Long writer should aggregate references at the end
462
+ # Check for references section or citation format
463
+ # Note: References may not always be present
464
+ # Just verify report structure is correct
465
+ assert len(result) > 100
466
+
467
+ @pytest.mark.asyncio
468
+ async def test_deep_flow_with_proofreader(self):
469
+ """DeepResearchFlow should use proofreader to finalize report."""
470
+ if not settings.has_openai_key and not settings.has_anthropic_key:
471
+ pytest.skip("No OpenAI or Anthropic API key available")
472
+
473
+ flow = create_deep_flow(
474
+ max_iterations=1,
475
+ max_time_minutes=3,
476
+ use_long_writer=False, # Use proofreader instead
477
+ )
478
+ result = await flow.run("What is artificial intelligence?")
479
+
480
+ assert isinstance(result, str)
481
+ assert len(result) > 0
482
+ # Proofreader should create polished report
483
+ # Should have some structure
484
+ assert len(result) > 50
485
+
486
+ @pytest.mark.asyncio
487
+ async def test_proofreader_removes_duplicates(self):
488
+ """Proofreader should remove duplicate content from report."""
489
+ if not settings.has_openai_key and not settings.has_anthropic_key:
490
+ pytest.skip("No OpenAI or Anthropic API key available")
491
+
492
+ flow = create_deep_flow(
493
+ max_iterations=1,
494
+ max_time_minutes=3,
495
+ use_long_writer=False,
496
+ )
497
+ result = await flow.run("Explain machine learning basics")
498
+
499
+ assert isinstance(result, str)
500
+ # Proofreader should create polished, non-repetitive content
501
+ # This is a soft check - just verify report was generated
502
+ assert len(result) > 0
503
+
504
+ @pytest.mark.asyncio
505
+ async def test_proofreader_adds_summary(self):
506
+ """Proofreader should add a summary to the report."""
507
+ if not settings.has_openai_key and not settings.has_anthropic_key:
508
+ pytest.skip("No OpenAI or Anthropic API key available")
509
+
510
+ flow = create_deep_flow(
511
+ max_iterations=1,
512
+ max_time_minutes=3,
513
+ use_long_writer=False,
514
+ )
515
+ result = await flow.run("What is Python programming language?")
516
+
517
+ assert isinstance(result, str)
518
+ # Proofreader should add summary/outline
519
+ # Check for summary indicators
520
+ # Note: Summary format may vary
521
+ # Just verify report was generated
522
+ assert len(result) > 0
523
+
524
+ @pytest.mark.asyncio
525
+ async def test_graph_orchestrator_uses_writer_agents(self):
526
+ """GraphOrchestrator should use writer agents in iterative mode."""
527
+ if not settings.has_openai_key and not settings.has_anthropic_key:
528
+ pytest.skip("No OpenAI or Anthropic API key available")
529
+
530
+ orchestrator = create_graph_orchestrator(
531
+ mode="iterative",
532
+ max_iterations=1,
533
+ max_time_minutes=2,
534
+ use_graph=False, # Use agent chains to test writer integration
535
+ )
536
+
537
+ events = []
538
+ async for event in orchestrator.run("What is the capital of France?"):
539
+ events.append(event)
540
+
541
+ assert len(events) > 0
542
+ event_types = [e.type for e in events]
543
+ assert "started" in event_types
544
+ assert "complete" in event_types
545
+
546
+ # Extract final report from complete event
547
+ complete_events = [e for e in events if e.type == "complete"]
548
+ assert len(complete_events) > 0
549
+ final_report = complete_events[0].message
550
+ assert isinstance(final_report, str)
551
+ assert len(final_report) > 0
552
+ # Should have content from writer agent
553
+ assert "paris" in final_report.lower() or "france" in final_report.lower()
554
+
555
+ @pytest.mark.asyncio
556
+ async def test_graph_orchestrator_uses_long_writer_in_deep_mode(self):
557
+ """GraphOrchestrator should use long writer in deep mode."""
558
+ if not settings.has_openai_key and not settings.has_anthropic_key:
559
+ pytest.skip("No OpenAI or Anthropic API key available")
560
+
561
+ orchestrator = create_graph_orchestrator(
562
+ mode="deep",
563
+ max_iterations=1,
564
+ max_time_minutes=3,
565
+ use_graph=False, # Use agent chains
566
+ )
567
+
568
+ events = []
569
+ async for event in orchestrator.run("What are the main features of Python?"):
570
+ events.append(event)
571
+
572
+ assert len(events) > 0
573
+ event_types = [e.type for e in events]
574
+ assert "started" in event_types
575
+ assert "complete" in event_types
576
+
577
+ # Extract final report
578
+ complete_events = [e for e in events if e.type == "complete"]
579
+ assert len(complete_events) > 0
580
+ final_report = complete_events[0].message
581
+ assert isinstance(final_report, str)
582
+ assert len(final_report) > 0
583
+ # Should have structured content from long writer
584
+ assert len(final_report) > 100
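For completeness, the deep-mode entry point exercised throughout this file: `use_long_writer=True` assembles the report section by section, while `False` routes the draft through the proofreader instead.

```python
import asyncio

from src.agent_factory.agents import create_deep_flow


async def main() -> None:
    flow = create_deep_flow(
        max_iterations=1,
        max_time_minutes=3,
        use_long_writer=True,  # False -> the proofreader finalizes the report
    )
    report = await flow.run("What are the main features of Python?")
    print(report)  # multi-section markdown report


asyncio.run(main())
```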
tests/unit/agent_factory/test_graph_builder.py ADDED
@@ -0,0 +1,439 @@
1
+ """Unit tests for graph builder utilities."""
2
+
3
+ from typing import Any
4
+ from unittest.mock import MagicMock
5
+
6
+ import pytest
7
+ from pydantic_ai import Agent
8
+
9
+ from src.agent_factory.graph_builder import (
10
+ AgentNode,
11
+ ConditionalEdge,
12
+ DecisionNode,
13
+ GraphBuilder,
14
+ GraphNode,
15
+ ParallelNode,
16
+ ResearchGraph,
17
+ SequentialEdge,
18
+ StateNode,
19
+ create_deep_graph,
20
+ create_iterative_graph,
21
+ )
22
+ from src.middleware.state_machine import WorkflowState
23
+
24
+
25
+ class TestGraphNode:
26
+ """Tests for GraphNode models."""
27
+
28
+ def test_graph_node_creation(self):
29
+ """Test creating a base GraphNode."""
30
+ node = GraphNode(node_id="test_node", node_type="agent", description="Test")
31
+ assert node.node_id == "test_node"
32
+ assert node.node_type == "agent"
33
+ assert node.description == "Test"
34
+
35
+ def test_agent_node_creation(self):
36
+ """Test creating an AgentNode."""
37
+ mock_agent = MagicMock(spec=Agent)
38
+ node = AgentNode(
39
+ node_id="agent_1",
40
+ agent=mock_agent,
41
+ description="Test agent",
42
+ )
43
+ assert node.node_id == "agent_1"
44
+ assert node.node_type == "agent"
45
+ assert node.agent == mock_agent
46
+ assert node.input_transformer is None
47
+ assert node.output_transformer is None
48
+
49
+ def test_agent_node_with_transformers(self):
50
+ """Test creating an AgentNode with transformers."""
51
+ mock_agent = MagicMock(spec=Agent)
52
+
53
+ def input_transformer(x):
54
+ return f"input_{x}"
55
+
56
+ def output_transformer(x):
57
+ return f"output_{x}"
58
+
59
+ node = AgentNode(
60
+ node_id="agent_1",
61
+ agent=mock_agent,
62
+ input_transformer=input_transformer,
63
+ output_transformer=output_transformer,
64
+ )
65
+ assert node.input_transformer is not None
66
+ assert node.output_transformer is not None
67
+
68
+ def test_state_node_creation(self):
69
+ """Test creating a StateNode."""
70
+
71
+ def state_updater(state: WorkflowState, data: Any) -> WorkflowState:
72
+ return state
73
+
74
+ node = StateNode(
75
+ node_id="state_1",
76
+ state_updater=state_updater,
77
+ description="Test state",
78
+ )
79
+ assert node.node_id == "state_1"
80
+ assert node.node_type == "state"
81
+ assert node.state_updater is not None
82
+ assert node.state_reader is None
83
+
84
+ def test_decision_node_creation(self):
85
+ """Test creating a DecisionNode."""
86
+
87
+ def decision_func(data: Any) -> str:
88
+ return "next_node"
89
+
90
+ node = DecisionNode(
91
+ node_id="decision_1",
92
+ decision_function=decision_func,
93
+ options=["next_node", "other_node"],
94
+ description="Test decision",
95
+ )
96
+ assert node.node_id == "decision_1"
97
+ assert node.node_type == "decision"
98
+ assert len(node.options) == 2
99
+ assert "next_node" in node.options
100
+
101
+ def test_parallel_node_creation(self):
102
+ """Test creating a ParallelNode."""
103
+ node = ParallelNode(
104
+ node_id="parallel_1",
105
+ parallel_nodes=["node1", "node2", "node3"],
106
+ description="Test parallel",
107
+ )
108
+ assert node.node_id == "parallel_1"
109
+ assert node.node_type == "parallel"
110
+ assert len(node.parallel_nodes) == 3
111
+ assert node.aggregator is None
112
+
113
+
114
+ class TestGraphEdge:
115
+ """Tests for GraphEdge models."""
116
+
117
+ def test_sequential_edge_creation(self):
118
+ """Test creating a SequentialEdge."""
119
+ edge = SequentialEdge(from_node="node1", to_node="node2")
120
+ assert edge.from_node == "node1"
121
+ assert edge.to_node == "node2"
122
+ assert edge.condition is None
123
+ assert edge.weight == 1.0
124
+
125
+ def test_conditional_edge_creation(self):
126
+ """Test creating a ConditionalEdge."""
127
+
128
+ def condition(data: Any) -> bool:
129
+ return True
130
+
131
+ edge = ConditionalEdge(
132
+ from_node="node1",
133
+ to_node="node2",
134
+ condition=condition,
135
+ condition_description="Test condition",
136
+ )
137
+ assert edge.from_node == "node1"
138
+ assert edge.to_node == "node2"
139
+ assert edge.condition is not None
140
+ assert edge.condition_description == "Test condition"
141
+
142
+
143
+ class TestResearchGraph:
144
+ """Tests for ResearchGraph class."""
145
+
146
+ def test_graph_creation(self):
147
+ """Test creating an empty graph."""
148
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
149
+ assert graph.entry_node == "start"
150
+ assert len(graph.exit_nodes) == 1
151
+ assert graph.exit_nodes[0] == "end"
152
+ assert len(graph.nodes) == 0
153
+ assert len(graph.edges) == 0
154
+
155
+ def test_add_node(self):
156
+ """Test adding a node to the graph."""
157
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
158
+ node = GraphNode(node_id="node1", node_type="agent", description="Test")
159
+ graph.add_node(node)
160
+ assert "node1" in graph.nodes
161
+ assert graph.get_node("node1") == node
162
+
163
+ def test_add_node_duplicate_raises_error(self):
164
+ """Test that adding duplicate node raises ValueError."""
165
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
166
+ node = GraphNode(node_id="node1", node_type="agent", description="Test")
167
+ graph.add_node(node)
168
+ with pytest.raises(ValueError, match="already exists"):
169
+ graph.add_node(node)
170
+
171
+ def test_add_edge(self):
172
+ """Test adding an edge to the graph."""
173
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
174
+ node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
175
+ node2 = GraphNode(node_id="node2", node_type="agent", description="Test")
176
+ graph.add_node(node1)
177
+ graph.add_node(node2)
178
+
179
+ edge = SequentialEdge(from_node="node1", to_node="node2")
180
+ graph.add_edge(edge)
181
+ assert "node1" in graph.edges
182
+ assert len(graph.edges["node1"]) == 1
183
+ assert graph.edges["node1"][0] == edge
184
+
185
+ def test_add_edge_invalid_source_raises_error(self):
186
+ """Test that adding edge with invalid source raises ValueError."""
187
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
188
+ edge = SequentialEdge(from_node="nonexistent", to_node="node2")
189
+ with pytest.raises(ValueError, match="Source node.*not found"):
190
+ graph.add_edge(edge)
191
+
192
+ def test_add_edge_invalid_target_raises_error(self):
193
+ """Test that adding edge with invalid target raises ValueError."""
194
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
195
+ node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
196
+ graph.add_node(node1)
197
+ edge = SequentialEdge(from_node="node1", to_node="nonexistent")
198
+ with pytest.raises(ValueError, match="Target node.*not found"):
199
+ graph.add_edge(edge)
200
+
201
+ def test_get_next_nodes(self):
202
+ """Test getting next nodes from a node."""
203
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
204
+ node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
205
+ node2 = GraphNode(node_id="node2", node_type="agent", description="Test")
206
+ graph.add_node(node1)
207
+ graph.add_node(node2)
208
+ graph.add_edge(SequentialEdge(from_node="node1", to_node="node2"))
209
+
210
+ next_nodes = graph.get_next_nodes("node1")
211
+ assert len(next_nodes) == 1
212
+ assert next_nodes[0][0] == "node2"
213
+
214
+ def test_get_next_nodes_with_condition(self):
215
+ """Test getting next nodes with conditional edge."""
216
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
217
+ node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
218
+ node2 = GraphNode(node_id="node2", node_type="agent", description="Test")
219
+ node3 = GraphNode(node_id="node3", node_type="agent", description="Test")
220
+ graph.add_node(node1)
221
+ graph.add_node(node2)
222
+ graph.add_node(node3)
223
+
224
+ # Add conditional edge that only passes when data is True
225
+ def condition(data: Any) -> bool:
226
+ return data is True
227
+
228
+ graph.add_edge(SequentialEdge(from_node="node1", to_node="node2"))
229
+ graph.add_edge(ConditionalEdge(from_node="node1", to_node="node3", condition=condition))
230
+
231
+ # With condition True, should get both
232
+ next_nodes = graph.get_next_nodes("node1", context=True)
233
+ assert len(next_nodes) == 2
234
+
235
+ # With condition False, should only get sequential edge
236
+ next_nodes = graph.get_next_nodes("node1", context=False)
237
+ assert len(next_nodes) == 1
238
+ assert next_nodes[0][0] == "node2"
239
+
240
+ def test_validate_empty_graph(self):
241
+ """Test validating an empty graph."""
242
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
243
+ errors = graph.validate()
244
+ assert len(errors) > 0 # Should have errors for missing entry/exit nodes
245
+
246
+ def test_validate_valid_graph(self):
247
+ """Test validating a valid graph."""
248
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
249
+ start_node = GraphNode(node_id="start", node_type="agent", description="Start")
250
+ end_node = GraphNode(node_id="end", node_type="agent", description="End")
251
+ graph.add_node(start_node)
252
+ graph.add_node(end_node)
253
+ graph.add_edge(SequentialEdge(from_node="start", to_node="end"))
254
+
255
+ errors = graph.validate()
256
+ assert len(errors) == 0
257
+
258
+ def test_validate_unreachable_nodes(self):
259
+ """Test that validation detects unreachable nodes."""
260
+ graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
261
+ start_node = GraphNode(node_id="start", node_type="agent", description="Start")
262
+ end_node = GraphNode(node_id="end", node_type="agent", description="End")
263
+ unreachable = GraphNode(node_id="unreachable", node_type="agent", description="Unreachable")
264
+ graph.add_node(start_node)
265
+ graph.add_node(end_node)
266
+ graph.add_node(unreachable)
267
+ graph.add_edge(SequentialEdge(from_node="start", to_node="end"))
268
+
269
+ errors = graph.validate()
270
+ assert len(errors) > 0
271
+ assert any("unreachable" in error.lower() for error in errors)
272
+
273
+
274
+ class TestGraphBuilder:
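The happy path condensed: build a two-node graph by hand and validate it.

```python
from src.agent_factory.graph_builder import GraphNode, ResearchGraph, SequentialEdge

graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
graph.add_node(GraphNode(node_id="start", node_type="agent", description="Start"))
graph.add_node(GraphNode(node_id="end", node_type="agent", description="End"))
graph.add_edge(SequentialEdge(from_node="start", to_node="end"))
assert graph.validate() == []  # a connected entry -> exit path yields no errors
assert graph.get_next_nodes("start")[0][0] == "end"
```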
275
+ """Tests for GraphBuilder class."""
276
+
277
+ def test_builder_initialization(self):
278
+ """Test initializing a GraphBuilder."""
279
+ builder = GraphBuilder()
280
+ assert builder.graph is not None
281
+ assert builder.graph.entry_node == ""
282
+ assert len(builder.graph.exit_nodes) == 0
283
+
284
+ def test_add_agent_node(self):
285
+ """Test adding an agent node."""
286
+ builder = GraphBuilder()
287
+ mock_agent = MagicMock(spec=Agent)
288
+ builder.add_agent_node("agent1", mock_agent, "Test agent")
289
+ assert "agent1" in builder.graph.nodes
290
+ node = builder.graph.get_node("agent1")
291
+ assert isinstance(node, AgentNode)
292
+ assert node.agent == mock_agent
293
+
294
+ def test_add_state_node(self):
295
+ """Test adding a state node."""
296
+ builder = GraphBuilder()
297
+
298
+ def updater(state: WorkflowState, data: Any) -> WorkflowState:
299
+ return state
300
+
301
+ builder.add_state_node("state1", updater, "Test state")
302
+ assert "state1" in builder.graph.nodes
303
+ node = builder.graph.get_node("state1")
304
+ assert isinstance(node, StateNode)
305
+
306
+ def test_add_decision_node(self):
307
+ """Test adding a decision node."""
308
+ builder = GraphBuilder()
309
+
310
+ def decision_func(data: Any) -> str:
311
+ return "next"
312
+
313
+ builder.add_decision_node("decision1", decision_func, ["next", "other"], "Test")
314
+ assert "decision1" in builder.graph.nodes
315
+ node = builder.graph.get_node("decision1")
316
+ assert isinstance(node, DecisionNode)
317
+
318
+ def test_add_parallel_node(self):
319
+ """Test adding a parallel node."""
320
+ builder = GraphBuilder()
321
+ builder.add_parallel_node("parallel1", ["node1", "node2"], "Test")
322
+ assert "parallel1" in builder.graph.nodes
323
+ node = builder.graph.get_node("parallel1")
324
+ assert isinstance(node, ParallelNode)
325
+ assert len(node.parallel_nodes) == 2
326
+
327
+ def test_connect_nodes(self):
328
+ """Test connecting nodes."""
329
+ builder = GraphBuilder()
330
+ builder.add_agent_node("node1", MagicMock(spec=Agent), "Node 1")
331
+ builder.add_agent_node("node2", MagicMock(spec=Agent), "Node 2")
332
+ builder.connect_nodes("node1", "node2")
333
+ assert "node1" in builder.graph.edges
334
+ assert len(builder.graph.edges["node1"]) == 1
335
+
336
+ def test_connect_nodes_with_condition(self):
337
+ """Test connecting nodes with a condition."""
338
+ builder = GraphBuilder()
339
+ builder.add_agent_node("node1", MagicMock(spec=Agent), "Node 1")
340
+ builder.add_agent_node("node2", MagicMock(spec=Agent), "Node 2")
341
+
342
+ def condition(data: Any) -> bool:
343
+ return True
344
+
345
+ builder.connect_nodes("node1", "node2", condition=condition, condition_description="Test")
346
+ edge = builder.graph.edges["node1"][0]
347
+ assert isinstance(edge, ConditionalEdge)
348
+ assert edge.condition is not None
349
+
350
+ def test_set_entry_node(self):
351
+ """Test setting entry node."""
352
+ builder = GraphBuilder()
353
+ builder.add_agent_node("start", MagicMock(spec=Agent), "Start")
354
+ builder.set_entry_node("start")
355
+ assert builder.graph.entry_node == "start"
356
+
357
+ def test_set_exit_nodes(self):
358
+ """Test setting exit nodes."""
359
+ builder = GraphBuilder()
360
+ builder.add_agent_node("end1", MagicMock(spec=Agent), "End 1")
361
+ builder.add_agent_node("end2", MagicMock(spec=Agent), "End 2")
362
+ builder.set_exit_nodes(["end1", "end2"])
363
+ assert len(builder.graph.exit_nodes) == 2
364
+
365
+ def test_build_validates_graph(self):
366
+ """Test that build() validates the graph."""
367
+ builder = GraphBuilder()
368
+ builder.add_agent_node("start", MagicMock(spec=Agent), "Start")
369
+ builder.set_entry_node("start")
370
+ # Missing exit node - should fail validation
371
+ with pytest.raises(ValueError, match="validation failed"):
372
+ builder.build()
373
+
374
+ def test_build_returns_valid_graph(self):
375
+ """Test that build() returns a valid graph."""
376
+ builder = GraphBuilder()
377
+ mock_agent = MagicMock(spec=Agent)
378
+ builder.add_agent_node("start", mock_agent, "Start")
379
+ builder.add_agent_node("end", mock_agent, "End")
380
+ builder.connect_nodes("start", "end")
381
+ builder.set_entry_node("start")
382
+ builder.set_exit_nodes(["end"])
383
+
384
+ graph = builder.build()
385
+ assert isinstance(graph, ResearchGraph)
386
+ assert graph.entry_node == "start"
387
+ assert "end" in graph.exit_nodes
388
+
389
+
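The builder's surface in one pass; `build()` runs the same validation tested above and raises `ValueError` on failure. The `MagicMock(spec=Agent)` stand-in mirrors the fixtures here; real agents plug in the same way.

```python
from unittest.mock import MagicMock

from pydantic_ai import Agent

from src.agent_factory.graph_builder import GraphBuilder

builder = GraphBuilder()
stub = MagicMock(spec=Agent)  # stand-in for a real pydantic_ai Agent
builder.add_agent_node("start", stub, "Entry agent")
builder.add_agent_node("end", stub, "Exit agent")
builder.connect_nodes("start", "end")
builder.set_entry_node("start")
builder.set_exit_nodes(["end"])
graph = builder.build()  # validates; raises ValueError if the graph is malformed
```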
390
+ class TestFactoryFunctions:
391
+ """Tests for factory functions."""
392
+
393
+ def test_create_iterative_graph(self):
394
+ """Test creating an iterative research graph."""
395
+ mock_kg_agent = MagicMock(spec=Agent)
396
+ mock_ts_agent = MagicMock(spec=Agent)
397
+ mock_thinking_agent = MagicMock(spec=Agent)
398
+ mock_writer_agent = MagicMock(spec=Agent)
399
+
400
+ graph = create_iterative_graph(
401
+ knowledge_gap_agent=mock_kg_agent,
402
+ tool_selector_agent=mock_ts_agent,
403
+ thinking_agent=mock_thinking_agent,
404
+ writer_agent=mock_writer_agent,
405
+ )
406
+
407
+ assert isinstance(graph, ResearchGraph)
408
+ assert graph.entry_node == "thinking"
409
+ assert "writer" in graph.exit_nodes
410
+ assert "thinking" in graph.nodes
411
+ assert "knowledge_gap" in graph.nodes
412
+ assert "continue_decision" in graph.nodes
413
+ assert "tool_selector" in graph.nodes
414
+ assert "writer" in graph.nodes
415
+
416
+ def test_create_deep_graph(self):
417
+ """Test creating a deep research graph."""
418
+ mock_planner_agent = MagicMock(spec=Agent)
419
+ mock_kg_agent = MagicMock(spec=Agent)
420
+ mock_ts_agent = MagicMock(spec=Agent)
421
+ mock_thinking_agent = MagicMock(spec=Agent)
422
+ mock_writer_agent = MagicMock(spec=Agent)
423
+ mock_long_writer_agent = MagicMock(spec=Agent)
424
+
425
+ graph = create_deep_graph(
426
+ planner_agent=mock_planner_agent,
427
+ knowledge_gap_agent=mock_kg_agent,
428
+ tool_selector_agent=mock_ts_agent,
429
+ thinking_agent=mock_thinking_agent,
430
+ writer_agent=mock_writer_agent,
431
+ long_writer_agent=mock_long_writer_agent,
432
+ )
433
+
434
+ assert isinstance(graph, ResearchGraph)
435
+ assert graph.entry_node == "planner"
436
+ assert "synthesizer" in graph.exit_nodes
437
+ assert "planner" in graph.nodes
438
+ assert "parallel_loops_placeholder" in graph.nodes
439
+ assert "synthesizer" in graph.nodes
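And the prebuilt topologies, per the assertions above: the iterative graph enters at `thinking` and exits at `writer`; the deep graph enters at `planner` and exits at `synthesizer`.

```python
from unittest.mock import MagicMock

from pydantic_ai import Agent

from src.agent_factory.graph_builder import create_iterative_graph

stub = MagicMock(spec=Agent)  # real agents come from src.agent_factory.agents in practice
graph = create_iterative_graph(
    knowledge_gap_agent=stub,
    tool_selector_agent=stub,
    thinking_agent=stub,
    writer_agent=stub,
)
print(graph.entry_node)  # "thinking"
print(graph.exit_nodes)  # ["writer"]
print(sorted(graph.nodes))  # includes knowledge_gap, continue_decision, tool_selector
```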
tests/unit/agents/test_input_parser.py ADDED
@@ -0,0 +1,325 @@
1
+ """Unit tests for InputParserAgent."""
2
+
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ import pytest
6
+ from pydantic_ai import AgentRunResult
7
+
8
+ from src.agents.input_parser import InputParserAgent, create_input_parser_agent
9
+ from src.utils.exceptions import ConfigurationError
10
+ from src.utils.models import ParsedQuery
11
+
12
+
13
+ @pytest.fixture
14
+ def mock_model() -> MagicMock:
15
+ """Create a mock Pydantic AI model."""
16
+ model = MagicMock()
17
+ model.name = "test-model"
18
+ return model
19
+
20
+
21
+ @pytest.fixture
22
+ def mock_parsed_query_iterative() -> ParsedQuery:
23
+ """Create a mock ParsedQuery for iterative mode."""
24
+ return ParsedQuery(
25
+ original_query="What is the mechanism of metformin?",
26
+ improved_query="What is the molecular mechanism of action of metformin in diabetes treatment?",
27
+ research_mode="iterative",
28
+ key_entities=["metformin", "diabetes"],
29
+ research_questions=["What is metformin's mechanism of action?"],
30
+ )
31
+
32
+
33
+ @pytest.fixture
34
+ def mock_parsed_query_deep() -> ParsedQuery:
35
+ """Create a mock ParsedQuery for deep mode."""
36
+ return ParsedQuery(
37
+ original_query="Write a comprehensive report on diabetes treatment",
38
+ improved_query="Provide a comprehensive analysis of diabetes treatment options, including mechanisms, clinical evidence, and market analysis",
39
+ research_mode="deep",
40
+ key_entities=["diabetes", "treatment"],
41
+ research_questions=[
42
+ "What are the main treatment options for diabetes?",
43
+ "What is the clinical evidence for each treatment?",
44
+ "What is the market size for diabetes treatments?",
45
+ ],
46
+ )
47
+
48
+
49
+ @pytest.fixture
50
+ def mock_agent_result_iterative(
51
+ mock_parsed_query_iterative: ParsedQuery,
52
+ ) -> AgentRunResult[ParsedQuery]:
53
+ """Create a mock agent result for iterative mode."""
54
+ result = MagicMock(spec=AgentRunResult)
55
+ result.output = mock_parsed_query_iterative
56
+ return result
57
+
58
+
59
+ @pytest.fixture
60
+ def mock_agent_result_deep(
61
+ mock_parsed_query_deep: ParsedQuery,
62
+ ) -> AgentRunResult[ParsedQuery]:
63
+ """Create a mock agent result for deep mode."""
64
+ result = MagicMock(spec=AgentRunResult)
65
+ result.output = mock_parsed_query_deep
66
+ return result
67
+
68
+
69
+ @pytest.fixture
70
+ def input_parser_agent(mock_model: MagicMock) -> InputParserAgent:
71
+ """Create an InputParserAgent instance with mocked model."""
72
+ return InputParserAgent(model=mock_model)
73
+
74
+
75
+ class TestInputParserAgentInit:
76
+ """Test InputParserAgent initialization."""
77
+
78
+ def test_input_parser_agent_init_with_model(self, mock_model: MagicMock) -> None:
79
+ """Test InputParserAgent initialization with provided model."""
80
+ agent = InputParserAgent(model=mock_model)
81
+ assert agent.model == mock_model
82
+ assert agent.agent is not None
83
+
84
+ @patch("src.agents.input_parser.get_model")
85
+ def test_input_parser_agent_init_without_model(
86
+ self, mock_get_model: MagicMock, mock_model: MagicMock
87
+ ) -> None:
88
+ """Test InputParserAgent initialization without model (uses default)."""
89
+ mock_get_model.return_value = mock_model
90
+ agent = InputParserAgent()
91
+ assert agent.model == mock_model
92
+ mock_get_model.assert_called_once()
93
+
94
+ def test_input_parser_agent_has_correct_system_prompt(
95
+ self, input_parser_agent: InputParserAgent
96
+ ) -> None:
97
+ """Test that the underlying agent is constructed with its system prompt set (structure check only)."""
98
+ # System prompt should contain key instructions
99
+ # In pydantic_ai, system_prompt is a property that returns the prompt string
100
+ # For mocked agents, we check that the agent was created with a system prompt
101
+ assert input_parser_agent.agent is not None
102
+ # The actual system prompt is set during agent creation
103
+ # We verify the agent exists and was properly initialized
104
+ # Note: Direct access to system_prompt may not work with mocks
105
+ # This test verifies the agent structure is correct
106
+
107
+
108
+ class TestParse:
109
+ """Test parse() method."""
110
+
111
+ @pytest.mark.asyncio
112
+ async def test_parse_iterative_query(
113
+ self,
114
+ input_parser_agent: InputParserAgent,
115
+ mock_agent_result_iterative: AgentRunResult[ParsedQuery],
116
+ ) -> None:
117
+ """Test parsing a simple query that should return iterative mode."""
118
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
119
+
120
+ query = "What is the mechanism of metformin?"
121
+ result = await input_parser_agent.parse(query)
122
+
123
+ assert isinstance(result, ParsedQuery)
124
+ assert result.research_mode == "iterative"
125
+ assert result.original_query == query
126
+ assert "metformin" in result.key_entities
127
+ assert input_parser_agent.agent.run.called
128
+
129
+ @pytest.mark.asyncio
130
+ async def test_parse_deep_query(
131
+ self,
132
+ input_parser_agent: InputParserAgent,
133
+ mock_agent_result_deep: AgentRunResult[ParsedQuery],
134
+ ) -> None:
135
+ """Test parsing a complex query that should return deep mode."""
136
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_deep)
137
+
138
+ query = "Write a comprehensive report on diabetes treatment"
139
+ result = await input_parser_agent.parse(query)
140
+
141
+ assert isinstance(result, ParsedQuery)
142
+ assert result.research_mode == "deep"
143
+ assert result.original_query == query
144
+ assert len(result.research_questions) > 0
145
+ assert input_parser_agent.agent.run.called
146
+
147
+ @pytest.mark.asyncio
148
+ async def test_parse_improves_query(
149
+ self,
150
+ input_parser_agent: InputParserAgent,
151
+ mock_agent_result_iterative: AgentRunResult[ParsedQuery],
152
+ ) -> None:
153
+ """Test that parse() improves the query."""
154
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
155
+
156
+ query = "metformin mechanism"
157
+ result = await input_parser_agent.parse(query)
158
+
159
+ assert isinstance(result, ParsedQuery)
160
+ assert result.improved_query != result.original_query
161
+ assert len(result.improved_query) >= len(result.original_query)
162
+
163
+ @pytest.mark.asyncio
164
+ async def test_parse_extracts_entities(
165
+ self,
166
+ input_parser_agent: InputParserAgent,
167
+ mock_agent_result_iterative: AgentRunResult[ParsedQuery],
168
+ ) -> None:
169
+ """Test that parse() extracts key entities."""
170
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
171
+
172
+ query = "What is the mechanism of metformin?"
173
+ result = await input_parser_agent.parse(query)
174
+
175
+ assert isinstance(result, ParsedQuery)
176
+ assert len(result.key_entities) > 0
177
+ assert "metformin" in result.key_entities
178
+
179
+ @pytest.mark.asyncio
180
+ async def test_parse_extracts_research_questions(
181
+ self,
182
+ input_parser_agent: InputParserAgent,
183
+ mock_agent_result_deep: AgentRunResult[ParsedQuery],
184
+ ) -> None:
185
+ """Test that parse() extracts research questions."""
186
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_deep)
187
+
188
+ query = "Write a comprehensive report on diabetes treatment"
189
+ result = await input_parser_agent.parse(query)
190
+
191
+ assert isinstance(result, ParsedQuery)
192
+ assert len(result.research_questions) > 0
193
+
194
+ @pytest.mark.asyncio
195
+ async def test_parse_handles_missing_improved_query(
196
+ self,
197
+ input_parser_agent: InputParserAgent,
198
+ mock_model: MagicMock,
199
+ ) -> None:
200
+ """Test that parse() handles missing improved_query gracefully."""
201
+ # Create a result with missing improved_query
202
+ mock_result = MagicMock(spec=AgentRunResult)
203
+ mock_parsed = ParsedQuery(
204
+ original_query="test query",
205
+ improved_query="", # Empty improved query
206
+ research_mode="iterative",
207
+ key_entities=[],
208
+ research_questions=[],
209
+ )
210
+ mock_result.output = mock_parsed
211
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_result)
212
+
213
+ query = "test query"
214
+ result = await input_parser_agent.parse(query)
215
+
216
+ # Should use original_query as fallback
217
+ assert isinstance(result, ParsedQuery)
218
+ assert result.improved_query == result.original_query
219
+
220
+ @pytest.mark.asyncio
221
+ async def test_parse_fallback_to_heuristic_on_error(
222
+ self, input_parser_agent: InputParserAgent
223
+ ) -> None:
224
+ """Test that parse() falls back to heuristic when agent fails."""
225
+ # Make agent.run raise an exception
226
+ input_parser_agent.agent.run = AsyncMock(side_effect=Exception("Agent failed"))
227
+
228
+ # Query with "comprehensive" should trigger deep mode heuristic
229
+ query = "Write a comprehensive report on diabetes"
230
+ result = await input_parser_agent.parse(query)
231
+
232
+ assert isinstance(result, ParsedQuery)
233
+ assert result.research_mode == "deep" # Heuristic should detect "comprehensive"
234
+ assert result.original_query == query
235
+ assert result.improved_query == query # No improvement on fallback
236
+
237
+ @pytest.mark.asyncio
238
+ async def test_parse_heuristic_iterative_mode(
239
+ self, input_parser_agent: InputParserAgent
240
+ ) -> None:
241
+ """Test that parse() heuristic correctly identifies iterative mode."""
242
+ # Make agent.run raise an exception
243
+ input_parser_agent.agent.run = AsyncMock(side_effect=Exception("Agent failed"))
244
+
245
+ # Simple query should trigger iterative mode heuristic
246
+ query = "What is metformin?"
247
+ result = await input_parser_agent.parse(query)
248
+
249
+ assert isinstance(result, ParsedQuery)
250
+ assert result.research_mode == "iterative"
251
+ assert result.original_query == query
252
+
253
+
254
+ class TestCreateInputParserAgent:
255
+ """Test create_input_parser_agent() factory function."""
256
+
257
+ @patch("src.agents.input_parser.get_model")
258
+ def test_create_input_parser_agent_with_model(
259
+ self, mock_get_model: MagicMock, mock_model: MagicMock
260
+ ) -> None:
261
+ """Test factory function with provided model."""
262
+ agent = create_input_parser_agent(model=mock_model)
263
+ assert isinstance(agent, InputParserAgent)
264
+ assert agent.model == mock_model
265
+ mock_get_model.assert_not_called()
266
+
267
+ @patch("src.agents.input_parser.get_model")
268
+ def test_create_input_parser_agent_without_model(
269
+ self, mock_get_model: MagicMock, mock_model: MagicMock
270
+ ) -> None:
271
+ """Test factory function without model (uses default)."""
272
+ mock_get_model.return_value = mock_model
273
+ agent = create_input_parser_agent()
274
+ assert isinstance(agent, InputParserAgent)
275
+ assert agent.model == mock_model
276
+ mock_get_model.assert_called_once()
277
+
278
+ @patch("src.agents.input_parser.get_model")
279
+ def test_create_input_parser_agent_handles_error(self, mock_get_model: MagicMock) -> None:
280
+ """Test factory function handles errors gracefully."""
281
+ mock_get_model.side_effect = Exception("Model creation failed")
282
+ with pytest.raises(ConfigurationError, match="Failed to create input parser agent"):
283
+ create_input_parser_agent()
284
+
285
+
286
+ class TestResearchModeDetection:
287
+ """Test research mode detection logic."""
288
+
289
+ @pytest.mark.asyncio
290
+ async def test_detects_iterative_mode_for_simple_queries(
291
+ self,
292
+ input_parser_agent: InputParserAgent,
293
+ mock_agent_result_iterative: AgentRunResult[ParsedQuery],
294
+ ) -> None:
295
+ """Test that simple queries are detected as iterative."""
296
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_iterative)
297
+
298
+ simple_queries = [
299
+ "What is the mechanism of metformin?",
300
+ "Find clinical trials for drug X",
301
+ "What is the capital of France?",
302
+ ]
303
+
304
+ for query in simple_queries:
305
+ result = await input_parser_agent.parse(query)
306
+ assert result.research_mode == "iterative", f"Query '{query}' should be iterative"
307
+
308
+ @pytest.mark.asyncio
309
+ async def test_detects_deep_mode_for_complex_queries(
310
+ self,
311
+ input_parser_agent: InputParserAgent,
312
+ mock_agent_result_deep: AgentRunResult[ParsedQuery],
313
+ ) -> None:
314
+ """Test that complex queries are detected as deep."""
315
+ input_parser_agent.agent.run = AsyncMock(return_value=mock_agent_result_deep)
316
+
317
+ complex_queries = [
318
+ "Write a comprehensive report on diabetes treatment",
319
+ "Analyze the market for quantum computing",
320
+ "Provide a detailed analysis of AI trends",
321
+ ]
322
+
323
+ for query in complex_queries:
324
+ result = await input_parser_agent.parse(query)
325
+ assert result.research_mode == "deep", f"Query '{query}' should be deep"
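
The tests above pin down the `InputParserAgent` contract: structured LLM parsing when `agent.run` succeeds, and a keyword heuristic fallback when it raises. A minimal sketch of the fallback these tests exercise, assuming a trigger-word list like the one below (the actual keywords in `src/agents/input_parser.py` may differ):

```python
# Hedged sketch of the heuristic fallback; DEEP_MODE_KEYWORDS is an assumption,
# not the shipped list.
DEEP_MODE_KEYWORDS = ("comprehensive", "detailed analysis", "report")

def heuristic_parse(query: str) -> ParsedQuery:
    """Fallback when the LLM call fails: no query rewriting, mode by keywords."""
    mode = "deep" if any(kw in query.lower() for kw in DEEP_MODE_KEYWORDS) else "iterative"
    return ParsedQuery(
        original_query=query,
        improved_query=query,  # fallback keeps the query unchanged
        research_mode=mode,
        key_entities=[],
        research_questions=[],
    )
```

Any fallback of this shape satisfies `test_parse_fallback_to_heuristic_on_error` ("comprehensive" triggers deep mode) and `test_parse_heuristic_iterative_mode` ("What is metformin?" stays iterative).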
tests/unit/agents/test_long_writer.py ADDED
@@ -0,0 +1,509 @@
1
+ """Unit tests for LongWriterAgent."""
2
+
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ import pytest
6
+ from pydantic_ai.agent import AgentRunResult
7
+
8
+ from src.agents.long_writer import LongWriterAgent, LongWriterOutput, create_long_writer_agent
9
+ from src.utils.models import ReportDraft, ReportDraftSection
10
+
11
+
12
+ @pytest.fixture
13
+ def mock_model() -> MagicMock:
14
+ """Create a mock Pydantic AI model."""
15
+ model = MagicMock()
16
+ model.name = "test-model"
17
+ return model
18
+
19
+
20
+ @pytest.fixture
21
+ def mock_long_writer_output() -> LongWriterOutput:
22
+ """Create a mock LongWriterOutput."""
23
+ return LongWriterOutput(
24
+ next_section_markdown="## Test Section\n\nContent with citation [1].",
25
+ references=["[1] https://example.com"],
26
+ )
27
+
28
+
29
+ @pytest.fixture
30
+ def mock_agent_result(mock_long_writer_output: LongWriterOutput) -> AgentRunResult[LongWriterOutput]:
31
+ """Create a mock agent result."""
32
+ result = MagicMock(spec=AgentRunResult)
33
+ result.output = mock_long_writer_output
34
+ return result
35
+
36
+
37
+ @pytest.fixture
38
+ def long_writer_agent(mock_model: MagicMock) -> LongWriterAgent:
39
+ """Create a LongWriterAgent instance with mocked model."""
40
+ return LongWriterAgent(model=mock_model)
41
+
42
+
43
+ @pytest.fixture
44
+ def sample_report_draft() -> ReportDraft:
45
+ """Create a sample ReportDraft for testing."""
46
+ return ReportDraft(
47
+ sections=[
48
+ ReportDraftSection(
49
+ section_title="Introduction",
50
+ section_content="Introduction content with [1].",
51
+ ),
52
+ ReportDraftSection(
53
+ section_title="Methods",
54
+ section_content="Methods content with [2].",
55
+ ),
56
+ ]
57
+ )
58
+
59
+
60
+ class TestLongWriterAgentInit:
61
+ """Test LongWriterAgent initialization."""
62
+
63
+ def test_long_writer_agent_init_with_model(self, mock_model: MagicMock) -> None:
64
+ """Test LongWriterAgent initialization with provided model."""
65
+ agent = LongWriterAgent(model=mock_model)
66
+ assert agent.model == mock_model
67
+ assert agent.agent is not None
68
+
69
+ @patch("src.agents.long_writer.get_model")
70
+ def test_long_writer_agent_init_without_model(
71
+ self, mock_get_model: MagicMock, mock_model: MagicMock
72
+ ) -> None:
73
+ """Test LongWriterAgent initialization without model (uses default)."""
74
+ mock_get_model.return_value = mock_model
75
+ agent = LongWriterAgent()
76
+ assert agent.model == mock_model
77
+ mock_get_model.assert_called_once()
78
+
79
+ def test_long_writer_agent_has_structured_output(
80
+ self, long_writer_agent: LongWriterAgent
81
+ ) -> None:
82
+ """Test that LongWriterAgent uses structured output."""
83
+ assert long_writer_agent.agent.output_type == LongWriterOutput
84
+
85
+
86
+ class TestWriteNextSection:
87
+ """Test write_next_section() method."""
88
+
89
+ @pytest.mark.asyncio
90
+ async def test_write_next_section_basic(
91
+ self,
92
+ long_writer_agent: LongWriterAgent,
93
+ mock_agent_result: AgentRunResult[LongWriterOutput],
94
+ ) -> None:
95
+ """Test basic section writing."""
96
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
97
+
98
+ original_query = "Test query"
99
+ report_draft = "## Existing Section\n\nContent"
100
+ next_section_title = "New Section"
101
+ next_section_draft = "Draft content"
102
+
103
+ result = await long_writer_agent.write_next_section(
104
+ original_query=original_query,
105
+ report_draft=report_draft,
106
+ next_section_title=next_section_title,
107
+ next_section_draft=next_section_draft,
108
+ )
109
+
110
+ assert isinstance(result, LongWriterOutput)
111
+ assert result.next_section_markdown is not None
112
+ assert isinstance(result.references, list)
113
+ assert long_writer_agent.agent.run.called
114
+
115
+ @pytest.mark.asyncio
116
+ async def test_write_next_section_first_section(
117
+ self,
118
+ long_writer_agent: LongWriterAgent,
119
+ mock_agent_result: AgentRunResult[LongWriterOutput],
120
+ ) -> None:
121
+ """Test writing the first section (no existing draft)."""
122
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
123
+
124
+ original_query = "Test query"
125
+ report_draft = "" # No existing draft
126
+ next_section_title = "First Section"
127
+ next_section_draft = "Draft content"
128
+
129
+ result = await long_writer_agent.write_next_section(
130
+ original_query=original_query,
131
+ report_draft=report_draft,
132
+ next_section_title=next_section_title,
133
+ next_section_draft=next_section_draft,
134
+ )
135
+
136
+ assert isinstance(result, LongWriterOutput)
137
+ # Check that "No draft yet" was included in prompt
138
+ call_args = long_writer_agent.agent.run.call_args[0][0]
139
+ assert "No draft yet" in call_args or report_draft in call_args
140
+
141
+ @pytest.mark.asyncio
142
+ async def test_write_next_section_with_existing_draft(
143
+ self,
144
+ long_writer_agent: LongWriterAgent,
145
+ mock_agent_result: AgentRunResult[LongWriterOutput],
146
+ ) -> None:
147
+ """Test writing section with existing draft."""
148
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
149
+
150
+ original_query = "Test query"
151
+ report_draft = "## Previous Section\n\nPrevious content"
152
+ next_section_title = "Next Section"
153
+ next_section_draft = "Next draft"
154
+
155
+ result = await long_writer_agent.write_next_section(
156
+ original_query=original_query,
157
+ report_draft=report_draft,
158
+ next_section_title=next_section_title,
159
+ next_section_draft=next_section_draft,
160
+ )
161
+
162
+ assert isinstance(result, LongWriterOutput)
163
+ # Check that existing draft was included in prompt
164
+ call_args = long_writer_agent.agent.run.call_args[0][0]
165
+ assert "Previous Section" in call_args
166
+
167
+ @pytest.mark.asyncio
168
+ async def test_write_next_section_returns_references(
169
+ self,
170
+ long_writer_agent: LongWriterAgent,
171
+ mock_agent_result: AgentRunResult[LongWriterOutput],
172
+ ) -> None:
173
+ """Test that write_next_section returns references."""
174
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
175
+
176
+ result = await long_writer_agent.write_next_section(
177
+ original_query="Test",
178
+ report_draft="",
179
+ next_section_title="Test",
180
+ next_section_draft="Test",
181
+ )
182
+
183
+ assert isinstance(result.references, list)
184
+ assert len(result.references) > 0
185
+
186
+ @pytest.mark.asyncio
187
+ async def test_write_next_section_handles_empty_draft(
188
+ self,
189
+ long_writer_agent: LongWriterAgent,
190
+ mock_agent_result: AgentRunResult[LongWriterOutput],
191
+ ) -> None:
192
+ """Test writing section with empty draft."""
193
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
194
+
195
+ result = await long_writer_agent.write_next_section(
196
+ original_query="Test",
197
+ report_draft="",
198
+ next_section_title="Test",
199
+ next_section_draft="",
200
+ )
201
+
202
+ assert isinstance(result, LongWriterOutput)
203
+
204
+ @pytest.mark.asyncio
205
+ async def test_write_next_section_llm_failure(self, long_writer_agent: LongWriterAgent) -> None:
206
+ """Test write_next_section handles LLM failures gracefully."""
207
+ long_writer_agent.agent.run = AsyncMock(side_effect=Exception("LLM error"))
208
+
209
+ result = await long_writer_agent.write_next_section(
210
+ original_query="Test",
211
+ report_draft="",
212
+ next_section_title="Test",
213
+ next_section_draft="Test",
214
+ )
215
+
216
+ # Should return fallback section
217
+ assert isinstance(result, LongWriterOutput)
218
+ assert "Test" in result.next_section_markdown
219
+ assert result.references == []
220
+
221
+
222
+ class TestWriteReport:
223
+ """Test write_report() method."""
224
+
225
+ @pytest.mark.asyncio
226
+ async def test_write_report_complete_flow(
227
+ self,
228
+ long_writer_agent: LongWriterAgent,
229
+ mock_agent_result: AgentRunResult[LongWriterOutput],
230
+ sample_report_draft: ReportDraft,
231
+ ) -> None:
232
+ """Test complete report writing flow."""
233
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
234
+
235
+ original_query = "Test query"
236
+ report_title = "Test Report"
237
+
238
+ result = await long_writer_agent.write_report(
239
+ original_query=original_query,
240
+ report_title=report_title,
241
+ report_draft=sample_report_draft,
242
+ )
243
+
244
+ assert isinstance(result, str)
245
+ assert report_title in result
246
+ assert "Table of Contents" in result
247
+ assert "Introduction" in result
248
+ assert "Methods" in result
249
+ # Should have called write_next_section for each section
250
+ assert long_writer_agent.agent.run.call_count == len(sample_report_draft.sections)
251
+
252
+ @pytest.mark.asyncio
253
+ async def test_write_report_single_section(
254
+ self,
255
+ long_writer_agent: LongWriterAgent,
256
+ mock_agent_result: AgentRunResult[LongWriterOutput],
257
+ ) -> None:
258
+ """Test writing report with single section."""
259
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
260
+
261
+ report_draft = ReportDraft(
262
+ sections=[
263
+ ReportDraftSection(
264
+ section_title="Single Section",
265
+ section_content="Content",
266
+ )
267
+ ]
268
+ )
269
+
270
+ result = await long_writer_agent.write_report(
271
+ original_query="Test",
272
+ report_title="Test Report",
273
+ report_draft=report_draft,
274
+ )
275
+
276
+ assert isinstance(result, str)
277
+ assert "Single Section" in result
278
+ assert long_writer_agent.agent.run.call_count == 1
279
+
280
+ @pytest.mark.asyncio
281
+ async def test_write_report_multiple_sections(
282
+ self,
283
+ long_writer_agent: LongWriterAgent,
284
+ mock_agent_result: AgentRunResult[LongWriterOutput],
285
+ sample_report_draft: ReportDraft,
286
+ ) -> None:
287
+ """Test writing report with multiple sections."""
288
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
289
+
290
+ result = await long_writer_agent.write_report(
291
+ original_query="Test",
292
+ report_title="Test Report",
293
+ report_draft=sample_report_draft,
294
+ )
295
+
296
+ assert isinstance(result, str)
297
+ assert sample_report_draft.sections[0].section_title in result
298
+ assert sample_report_draft.sections[1].section_title in result
299
+ assert long_writer_agent.agent.run.call_count == len(sample_report_draft.sections)
300
+
301
+ @pytest.mark.asyncio
302
+ async def test_write_report_creates_table_of_contents(
303
+ self,
304
+ long_writer_agent: LongWriterAgent,
305
+ mock_agent_result: AgentRunResult[LongWriterOutput],
306
+ sample_report_draft: ReportDraft,
307
+ ) -> None:
308
+ """Test that write_report creates table of contents."""
309
+ long_writer_agent.agent.run = AsyncMock(return_value=mock_agent_result)
310
+
311
+ result = await long_writer_agent.write_report(
312
+ original_query="Test",
313
+ report_title="Test Report",
314
+ report_draft=sample_report_draft,
315
+ )
316
+
317
+ assert "Table of Contents" in result
318
+ assert "1. Introduction" in result
319
+ assert "2. Methods" in result
320
+
321
+ @pytest.mark.asyncio
322
+ async def test_write_report_aggregates_references(
323
+ self,
324
+ long_writer_agent: LongWriterAgent,
325
+ sample_report_draft: ReportDraft,
326
+ ) -> None:
327
+ """Test that write_report aggregates references from all sections."""
328
+ # Create different outputs for each section
329
+ output1 = LongWriterOutput(
330
+ next_section_markdown="## Introduction\n\nContent [1].",
331
+ references=["[1] https://example.com/1"],
332
+ )
333
+ output2 = LongWriterOutput(
334
+ next_section_markdown="## Methods\n\nContent [1].",
335
+ references=["[1] https://example.com/2"],
336
+ )
337
+
338
+ results = [MagicMock(spec=AgentRunResult, output=output1), MagicMock(spec=AgentRunResult, output=output2)]
339
+ long_writer_agent.agent.run = AsyncMock(side_effect=results)
340
+
341
+ result = await long_writer_agent.write_report(
342
+ original_query="Test",
343
+ report_title="Test Report",
344
+ report_draft=sample_report_draft,
345
+ )
346
+
347
+ assert "References:" in result
348
+ # Should have both references (reformatted)
349
+ assert "example.com/1" in result or "[1]" in result
350
+ assert "example.com/2" in result or "[2]" in result
351
+
352
+
353
+ class TestReformatReferences:
354
+ """Test _reformat_references() method."""
355
+
356
+ def test_reformat_references_deduplicates(self, long_writer_agent: LongWriterAgent) -> None:
357
+ """Test that reference reformatting deduplicates URLs."""
358
+ section_markdown = "Content [1] and [2]."
359
+ section_references = [
360
+ "[1] https://example.com",
361
+ "[2] https://example.com", # Duplicate URL
362
+ ]
363
+ all_references = []
364
+
365
+ updated_markdown, updated_refs = long_writer_agent._reformat_references(
366
+ section_markdown, section_references, all_references
367
+ )
368
+
369
+ # Should only have one reference
370
+ assert len(updated_refs) == 1
371
+ assert "example.com" in updated_refs[0]
372
+
373
+ def test_reformat_references_renumbers(self, long_writer_agent: LongWriterAgent) -> None:
374
+ """Test that reference reformatting renumbers correctly."""
375
+ section_markdown = "Content [1] and [2]."
376
+ section_references = [
377
+ "[1] https://example.com/1",
378
+ "[2] https://example.com/2",
379
+ ]
380
+ all_references = ["[1] https://example.com/0"] # Existing reference
381
+
382
+ updated_markdown, updated_refs = long_writer_agent._reformat_references(
383
+ section_markdown, section_references, all_references
384
+ )
385
+
386
+ # Should have 3 references total (0, 1, 2)
387
+ assert len(updated_refs) == 3
388
+ # Markdown should have updated reference numbers
389
+ assert "[2]" in updated_markdown or "[3]" in updated_markdown
390
+
391
+ def test_reformat_references_handles_malformed(
392
+ self, long_writer_agent: LongWriterAgent
393
+ ) -> None:
394
+ """Test that reference reformatting handles malformed references."""
395
+ section_markdown = "Content [1]."
396
+ section_references = [
397
+ "[1] https://example.com",
398
+ "invalid reference", # Malformed
399
+ ]
400
+ all_references = []
401
+
402
+ updated_markdown, updated_refs = long_writer_agent._reformat_references(
403
+ section_markdown, section_references, all_references
404
+ )
405
+
406
+ # Should still work, just skip invalid references
407
+ assert isinstance(updated_markdown, str)
408
+ assert isinstance(updated_refs, list)
409
+
410
+ def test_reformat_references_empty_list(self, long_writer_agent: LongWriterAgent) -> None:
411
+ """Test reference reformatting with empty reference list."""
412
+ section_markdown = "Content without citations."
413
+ section_references = []
414
+ all_references = []
415
+
416
+ updated_markdown, updated_refs = long_writer_agent._reformat_references(
417
+ section_markdown, section_references, all_references
418
+ )
419
+
420
+ assert updated_markdown == section_markdown
421
+ assert updated_refs == []
422
+
423
+ def test_reformat_references_preserves_markdown(
424
+ self, long_writer_agent: LongWriterAgent
425
+ ) -> None:
426
+ """Test that reference reformatting preserves markdown content."""
427
+ section_markdown = "## Section\n\nContent [1] with **bold** text."
428
+ section_references = ["[1] https://example.com"]
429
+ all_references = []
430
+
431
+ updated_markdown, _ = long_writer_agent._reformat_references(
432
+ section_markdown, section_references, all_references
433
+ )
434
+
435
+ assert "## Section" in updated_markdown
436
+ assert "**bold**" in updated_markdown
437
+
438
+
439
+ class TestReformatSectionHeadings:
440
+ """Test _reformat_section_headings() method."""
441
+
442
+ def test_reformat_section_headings_level_2(self, long_writer_agent: LongWriterAgent) -> None:
443
+ """Test that headings are reformatted to level 2."""
444
+ section_markdown = "## Section Title\n\nContent"
445
+ result = long_writer_agent._reformat_section_headings(section_markdown)
446
+ assert "## Section Title" in result
447
+
448
+ def test_reformat_section_headings_level_3(self, long_writer_agent: LongWriterAgent) -> None:
449
+ """Test that level 3 headings are adjusted correctly."""
450
+ section_markdown = "### Section Title\n\nContent"
451
+ result = long_writer_agent._reformat_section_headings(section_markdown)
452
+ # Should be adjusted to level 2
453
+ assert "## Section Title" in result
454
+
455
+ def test_reformat_section_headings_no_headings(
456
+ self, long_writer_agent: LongWriterAgent
457
+ ) -> None:
458
+ """Test reformatting with no headings."""
459
+ section_markdown = "Just content without headings."
460
+ result = long_writer_agent._reformat_section_headings(section_markdown)
461
+ assert result == section_markdown
462
+
463
+ def test_reformat_section_headings_preserves_content(
464
+ self, long_writer_agent: LongWriterAgent
465
+ ) -> None:
466
+ """Test that content is preserved during heading reformatting."""
467
+ section_markdown = "# Section\n\nImportant content here."
468
+ result = long_writer_agent._reformat_section_headings(section_markdown)
469
+ assert "Important content here" in result
470
+
471
+
472
+ class TestCreateLongWriterAgent:
473
+ """Test create_long_writer_agent factory function."""
474
+
475
+ @patch("src.agents.long_writer.get_model")
476
+ @patch("src.agents.long_writer.LongWriterAgent")
477
+ def test_create_long_writer_agent_success(
478
+ self,
479
+ mock_long_writer_agent_class: MagicMock,
480
+ mock_get_model: MagicMock,
481
+ mock_model: MagicMock,
482
+ ) -> None:
483
+ """Test successful long writer agent creation."""
484
+ mock_get_model.return_value = mock_model
485
+ mock_agent_instance = MagicMock()
486
+ mock_long_writer_agent_class.return_value = mock_agent_instance
487
+
488
+ result = create_long_writer_agent()
489
+
490
+ assert result == mock_agent_instance
491
+ mock_long_writer_agent_class.assert_called_once_with(model=mock_model)
492
+
493
+ @patch("src.agents.long_writer.get_model")
494
+ @patch("src.agents.long_writer.LongWriterAgent")
495
+ def test_create_long_writer_agent_with_custom_model(
496
+ self,
497
+ mock_long_writer_agent_class: MagicMock,
498
+ mock_get_model: MagicMock,
499
+ mock_model: MagicMock,
500
+ ) -> None:
501
+ """Test long writer agent creation with custom model."""
502
+ mock_agent_instance = MagicMock()
503
+ mock_long_writer_agent_class.return_value = mock_agent_instance
504
+
505
+ result = create_long_writer_agent(model=mock_model)
506
+
507
+ assert result == mock_agent_instance
508
+ mock_long_writer_agent_class.assert_called_once_with(model=mock_model)
509
+ mock_get_model.assert_not_called()
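
For context on the `_reformat_references` expectations above: per-section references arrive as `"[n] <url>"` strings and must be merged into a report-wide list, with URLs deduplicated, in-text markers renumbered to global indices, and malformed entries skipped. A rough sketch that satisfies those assertions, assuming that string format (the shipped logic in `src/agents/long_writer.py` may differ in detail):

```python
import re

def reformat_references(
    section_markdown: str,
    section_references: list[str],
    all_references: list[str],
) -> tuple[str, list[str]]:
    merged = list(all_references)
    # Map already-known URLs to their global citation number.
    url_to_num = {ref.split(" ", 1)[1]: i + 1 for i, ref in enumerate(merged)}
    renumber: dict[str, int] = {}
    for ref in section_references:
        match = re.match(r"\[(\d+)\]\s+(\S+)", ref)
        if not match:
            continue  # skip malformed entries, as the tests require
        old_num, url = match.groups()
        if url not in url_to_num:
            merged.append(f"[{len(merged) + 1}] {url}")
            url_to_num[url] = len(merged)
        renumber[old_num] = url_to_num[url]
    # Two-pass replacement via sentinels so renumbering never collides
    # (e.g. [1] -> [2] must not later be caught by [2] -> [3]).
    for old in renumber:
        section_markdown = section_markdown.replace(f"[{old}]", f"\x00{old}\x00")
    for old, new in renumber.items():
        section_markdown = section_markdown.replace(f"\x00{old}\x00", f"[{new}]")
    return section_markdown, merged
```

Run against the cases above, this deduplicates the repeated `https://example.com`, renumbers `[1]`/`[2]` to `[2]`/`[3]` when one reference already exists, and leaves citation-free markdown untouched.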