File size: 5,660 Bytes
cf02b2b
 
 
b971859
 
cf02b2b
 
 
 
 
 
 
b971859
 
cf02b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from langgraph.graph import StateGraph, END, START
from langchain_core.messages import SystemMessage
from typing import TypedDict
# Temporarily disabled due to protobuf issues
# from config import GOOGLE_API_KEY, GEMINI_MODEL, GEMINI_TEMPERATURE
from rag_service import search_docs, search_government_docs, analyze_scenario
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
# Temporarily disabled due to protobuf issues
# from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Literal
from langchain_core.tools import tool
import asyncio

# Optional MCP client - install with: pip install langchain-mcp-adapters
try:
    from langchain_mcp_adapters.client import MultiServerMCPClient
    MCP_AVAILABLE = True
except ImportError:
    # Adapter package not installed; create_graph() will skip MCP tools.
    MCP_AVAILABLE = False

# Optional Tavily search - requires TAVILY_API_KEY environment variable
try:
    tavily_search = TavilySearch(max_results=4)
    TAVILY_AVAILABLE = True
except Exception:
    # Constructor raises when the API key is missing/invalid; search_tool()
    # then returns an "unavailable" message instead of attempting a search.
    TAVILY_AVAILABLE = False


@tool
def search_tool(query: str):
    """
    Run an advanced web search through the Tavily Search API with hardcoded
    options.

    Parameters:
    ----------
    query : str
        The search query string.

    Returns:
    -------
    str
        The search results returned by the Tavily Search API, or a message
        explaining why the search could not be performed.

    Raises:
    ------
    Exception
        Errors during the search are caught and returned as error strings.
    """
    if not TAVILY_AVAILABLE:
        return "Web search is not available. Tavily API key is not configured."

    try:
        return tavily_search.invoke({"query": query, "auto_parameters": True})
    except Exception as exc:
        return f"Error during Tavily search: {str(exc)}"


# State definition
class State(TypedDict):
    # Conversation state shared by all graph nodes.
    # add_messages is known as a reducer, where it does not modify the list but adds messages to it
    messages: Annotated[list, add_messages]
    # messages: Annotated[list[BaseMessage], add_messages]
    # both have same result no need to use BaseMessage


async def create_graph(kb_tool: bool, mcp_config: dict):
    """
    Build a tool-calling chat graph: an LLM node and a ToolNode in a loop.

    Parameters
    ----------
    kb_tool : bool
        When True, include the knowledge-base tools (search_docs,
        search_government_docs) alongside the web search tool.
    mcp_config : dict
        Optional MCP server config with keys "name", "url" and optionally
        "bearerToken". Ignored when falsy or when the MCP adapter package
        is not installed.

    Returns
    -------
    A compiled LangGraph graph whose state is `State`.
    """
    # Gather tools from an optional MCP server.
    if mcp_config and MCP_AVAILABLE:
        server_config = {
            "url": mcp_config["url"],
            "transport": "streamable_http",
        }

        # Add headers if bearer token exists
        if mcp_config.get("bearerToken"):
            server_config["headers"] = {
                "Authorization": f"Bearer {mcp_config['bearerToken']}"
            }

        client = MultiServerMCPClient({mcp_config["name"]: server_config})
        mcp_tools = await client.get_tools()
    else:
        mcp_tools = []

    # NOTE(review): the ChatGoogleGenerativeAI / GEMINI_* imports are commented
    # out at the top of this file ("protobuf issues"), so the next statement
    # raises NameError until they are re-enabled — confirm before shipping.
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        google_api_key=GOOGLE_API_KEY,
        temperature=GEMINI_TEMPERATURE,
    )

    if kb_tool:
        tools = [search_docs, search_government_docs, analyze_scenario, search_tool]
    else:
        tools = [search_tool, analyze_scenario]
    tools = tools + mcp_tools
    llm_with_tools = llm.bind_tools(tools)

    async def llm_node(state: State):
        # One LLM step; the add_messages reducer appends the response.
        response = await llm_with_tools.ainvoke(state["messages"])
        return {"messages": [response]}

    builder = StateGraph(State)
    builder.add_node("llm_with_tools", llm_node)
    tool_node = ToolNode(tools=tools, handle_tool_errors=True)
    builder.add_node("tools", tool_node)
    builder.add_edge(START, "llm_with_tools")
    # tools_condition routes to "tools" when the last message has tool calls
    # and to END otherwise, so no unconditional END edge is needed.
    # BUGFIX: the original also called add_edge("llm_with_tools", END), which
    # scheduled END on every step in parallel with the conditional target.
    builder.add_conditional_edges("llm_with_tools", tools_condition)
    builder.add_edge("tools", "llm_with_tools")
    return builder.compile()


# Build basic graph (no tools, no memory)
def create_basic_graph():
    """
    Build a minimal single-node chat graph: one LLM call per turn, no tools,
    no checkpointing.

    Returns
    -------
    A compiled LangGraph graph whose state is `State`.
    """
    # NOTE(review): the ChatGoogleGenerativeAI / GEMINI_* imports are commented
    # out at the top of this file ("protobuf issues"), so the next statement
    # raises NameError until they are re-enabled — confirm before shipping.
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        google_api_key=GOOGLE_API_KEY,
        temperature=GEMINI_TEMPERATURE,
    )

    async def llm_basic_node(state: State):
        messages = state["messages"]
        system_prompt = SystemMessage(
            content="""You are a helpful and friendly voice AI assistant. Your responses should be:

    - Conversational and natural, as if speaking to a friend
    - Concise but informative - aim for 1-3 sentences unless more detail is specifically requested
    - Clear and easy to understand when spoken aloud
    - Engaging and personable while remaining professional
    - Avoid overly complex language or long lists that are hard to follow in audio format

    When responding:
    - Use a warm, approachable tone
    - Speak in a natural rhythm suitable for text-to-speech
    - If you need to provide multiple items or steps, break them into digestible chunks
    - Ask clarifying questions when needed to better assist the user
    - Acknowledge when you don't know something rather than guessing

    Remember that users are interacting with you through voice, so structure your responses to be easily understood when heard rather than read.
    Dont use abbreviations or numerical content in your responses."""
        )
        # BUGFIX: prepend into a NEW list instead of the original
        # messages.insert(0, ...), which mutated the shared graph-state list
        # in place.
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [system_prompt] + messages
        # BUGFIX: use the async API — the original called the synchronous
        # llm.invoke() inside an async node, blocking the event loop.
        response = await llm.ainvoke(messages)
        return {"messages": [response]}

    builder = StateGraph(State)
    builder.add_node("llm_basic", llm_basic_node)
    builder.set_entry_point("llm_basic")
    builder.add_edge("llm_basic", END)
    return builder.compile()  # No checkpointing