Joseph Pollack commited on
Commit
687a1f1
·
unverified ·
1 Parent(s): 731a241

adds the initial iterative and deep research workflows

Browse files
src/middleware/__init__.py CHANGED
@@ -19,15 +19,12 @@ from src.middleware.workflow_manager import (
19
  )
20
 
21
  __all__ = [
22
- # State management
 
 
 
 
23
  "WorkflowState",
24
- "init_workflow_state",
25
  "get_workflow_state",
26
- # Workflow management
27
- "WorkflowManager",
28
- "ResearchLoop",
29
- "LoopStatus",
30
- # Budget tracking
31
- "BudgetTracker",
32
- "BudgetStatus",
33
  ]
 
19
  )
20
 
21
  __all__ = [
22
+ "BudgetStatus",
23
+ "BudgetTracker",
24
+ "LoopStatus",
25
+ "ResearchLoop",
26
+ "WorkflowManager",
27
  "WorkflowState",
 
28
  "get_workflow_state",
29
+ "init_workflow_state",
 
 
 
 
 
 
30
  ]
src/orchestrator/__init__.py CHANGED
@@ -36,13 +36,13 @@ from src.orchestrator.planner_agent import PlannerAgent, create_planner_agent
36
  from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
37
 
38
  __all__ = [
39
- "PlannerAgent",
40
- "create_planner_agent",
41
- "IterativeResearchFlow",
42
  "DeepResearchFlow",
43
  "GraphOrchestrator",
44
- "create_graph_orchestrator",
45
- "SearchHandlerProtocol",
46
  "JudgeHandlerProtocol",
47
  "Orchestrator",
 
 
 
 
48
  ]
 
36
  from src.orchestrator.research_flow import DeepResearchFlow, IterativeResearchFlow
37
 
38
  __all__ = [
 
 
 
39
  "DeepResearchFlow",
40
  "GraphOrchestrator",
41
+ "IterativeResearchFlow",
 
42
  "JudgeHandlerProtocol",
43
  "Orchestrator",
44
+ "PlannerAgent",
45
+ "SearchHandlerProtocol",
46
+ "create_graph_orchestrator",
47
+ "create_planner_agent",
48
  ]
src/orchestrator_factory.py CHANGED
@@ -2,7 +2,11 @@
2
 
3
  from typing import Any, Literal
4
 
5
- from src.legacy_orchestrator import JudgeHandlerProtocol, Orchestrator, SearchHandlerProtocol
 
 
 
 
6
  from src.utils.models import OrchestratorConfig
7
 
8
 
 
2
 
3
  from typing import Any, Literal
4
 
5
+ from src.legacy_orchestrator import (
6
+ JudgeHandlerProtocol,
7
+ Orchestrator,
8
+ SearchHandlerProtocol,
9
+ )
10
  from src.utils.models import OrchestratorConfig
11
 
12
 
src/tools/__init__.py CHANGED
@@ -7,9 +7,9 @@ from src.tools.search_handler import SearchHandler
7
 
8
  # Re-export
9
  __all__ = [
 
10
  "PubMedTool",
11
  "SearchHandler",
12
  "SearchTool",
13
- "RAGTool",
14
  "create_rag_tool",
15
  ]
 
7
 
8
  # Re-export
9
  __all__ = [
10
+ "RAGTool",
11
  "PubMedTool",
12
  "SearchHandler",
13
  "SearchTool",
 
14
  "create_rag_tool",
15
  ]
src/tools/search_handler.py CHANGED
@@ -12,6 +12,8 @@ from src.utils.models import Evidence, SearchResult, SourceName
12
 
13
  if TYPE_CHECKING:
14
  from src.services.llamaindex_rag import LlamaIndexRAGService
 
 
15
 
16
  logger = structlog.get_logger()
17
 
@@ -38,7 +40,7 @@ class SearchHandler:
38
  self.tools = list(tools) # Make a copy
39
  self.timeout = timeout
40
  self.auto_ingest_to_rag = auto_ingest_to_rag
41
- self._rag_service: "LlamaIndexRAGService | None" = None
42
 
43
  if include_rag:
44
  self.add_rag_tool()
 
12
 
13
  if TYPE_CHECKING:
14
  from src.services.llamaindex_rag import LlamaIndexRAGService
15
+ else:
16
+ LlamaIndexRAGService = object
17
 
18
  logger = structlog.get_logger()
19
 
 
40
  self.tools = list(tools) # Make a copy
41
  self.timeout = timeout
42
  self.auto_ingest_to_rag = auto_ingest_to_rag
43
+ self._rag_service: LlamaIndexRAGService | None = None
44
 
45
  if include_rag:
46
  self.add_rag_tool()
tests/unit/agent_factory/test_graph_builder.py CHANGED
@@ -186,7 +186,7 @@ class TestResearchGraph:
186
  """Test that adding edge with invalid source raises ValueError."""
187
  graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
188
  edge = SequentialEdge(from_node="nonexistent", to_node="node2")
189
- with pytest.raises(ValueError, match="Source node.*not found"):
190
  graph.add_edge(edge)
191
 
192
  def test_add_edge_invalid_target_raises_error(self):
@@ -195,7 +195,7 @@ class TestResearchGraph:
195
  node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
196
  graph.add_node(node1)
197
  edge = SequentialEdge(from_node="node1", to_node="nonexistent")
198
- with pytest.raises(ValueError, match="Target node.*not found"):
199
  graph.add_edge(edge)
200
 
201
  def test_get_next_nodes(self):
 
186
  """Test that adding edge with invalid source raises ValueError."""
187
  graph = ResearchGraph(entry_node="start", exit_nodes=["end"])
188
  edge = SequentialEdge(from_node="nonexistent", to_node="node2")
189
+ with pytest.raises(ValueError, match=r"Source node.*not found"):
190
  graph.add_edge(edge)
191
 
192
  def test_add_edge_invalid_target_raises_error(self):
 
195
  node1 = GraphNode(node_id="node1", node_type="agent", description="Test")
196
  graph.add_node(node1)
197
  edge = SequentialEdge(from_node="node1", to_node="nonexistent")
198
+ with pytest.raises(ValueError, match=r"Target node.*not found"):
199
  graph.add_edge(edge)
200
 
201
  def test_get_next_nodes(self):
tests/unit/agents/test_long_writer.py CHANGED
@@ -362,7 +362,7 @@ class TestReformatReferences:
362
  ]
363
  all_references = []
364
 
365
- updated_markdown, updated_refs = long_writer_agent._reformat_references(
366
  section_markdown, section_references, all_references
367
  )
368
 
 
362
  ]
363
  all_references = []
364
 
365
+ _updated_markdown, updated_refs = long_writer_agent._reformat_references(
366
  section_markdown, section_references, all_references
367
  )
368