from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from typing import Annotated, Optional
from typing_extensions import TypedDict
# Required by create_graph / create_basic_graph below (previously disabled
# over protobuf issues, but the module cannot run without them).
from config import GOOGLE_API_KEY, GEMINI_MODEL, GEMINI_TEMPERATURE
from langchain_google_genai import ChatGoogleGenerativeAI
from rag_service import search_docs, search_government_docs, analyze_scenario
# Optional MCP client - install with: pip install langchain-mcp-adapters
try:
    from langchain_mcp_adapters.client import MultiServerMCPClient
    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False

# Optional Tavily search - requires TAVILY_API_KEY environment variable
try:
    tavily_search = TavilySearch(max_results=4)
    TAVILY_AVAILABLE = True
except Exception:
    TAVILY_AVAILABLE = False
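# One way to enable web search before launching the app (the key value below
# is a placeholder, not a real credential):
#
#   export TAVILY_API_KEY="tvly-..."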

@tool
def search_tool(query: str) -> str:
    """
    Perform an advanced web search using the Tavily Search API with hardcoded options.

    Parameters
    ----------
    query : str
        The search query string.

    Returns
    -------
    str
        The search results returned by the Tavily Search API, or an error
        message: exceptions raised during the search are caught and returned
        as strings rather than propagated.
    """
    if not TAVILY_AVAILABLE:
        return "Web search is not available. Tavily API key is not configured."
    query_params = {"query": query, "auto_parameters": True}
    try:
        result = tavily_search.invoke(query_params)
        return result
    except Exception as e:
        return f"Error during Tavily search: {str(e)}"

# State definition
class State(TypedDict):
    # add_messages is a reducer: rather than replacing the list, it appends
    # new messages to the existing one.
    messages: Annotated[list, add_messages]
    # Annotating as list[BaseMessage] behaves identically, so the plain
    # list annotation is used here.
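# Illustrative effect of the reducer (messages shown as plain strings for brevity):
#
#   current state: {"messages": ["hi"]}
#   node returns:  {"messages": ["hello!"]}
#   merged state:  {"messages": ["hi", "hello!"]}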

async def create_graph(kb_tool: bool, mcp_config: Optional[dict]):
    """Build the tool-calling agent graph.

    kb_tool: when True, include the knowledge-base retrieval tools.
    mcp_config: optional dict with "name", "url" and (optionally)
        "bearerToken" keys describing a remote MCP server; pass None
        (or leave MCP adapters uninstalled) to skip MCP tools.
    """
    if mcp_config and MCP_AVAILABLE:
        server_config = {
            "url": mcp_config["url"],
            "transport": "streamable_http",
        }
        # Add headers if bearer token exists
        if mcp_config.get("bearerToken"):
            server_config["headers"] = {
                "Authorization": f"Bearer {mcp_config['bearerToken']}"
            }
        client = MultiServerMCPClient({mcp_config["name"]: server_config})
        mcp_tools = await client.get_tools()
    else:
        mcp_tools = []
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        google_api_key=GOOGLE_API_KEY,
        temperature=GEMINI_TEMPERATURE,
    )
    if kb_tool:
        tools = [search_docs, search_government_docs, analyze_scenario, search_tool]
    else:
        tools = [search_tool, analyze_scenario]
    tools = tools + mcp_tools
    llm_with_tools = llm.bind_tools(tools)

    async def llm_node(state: State):
        messages = state["messages"]
        response = await llm_with_tools.ainvoke(messages)
        return {"messages": [response]}

    builder = StateGraph(State)
    builder.add_node("llm_with_tools", llm_node)
    tool_node = ToolNode(tools=tools, handle_tool_errors=True)
    builder.add_node("tools", tool_node)
    builder.add_edge(START, "llm_with_tools")
    # tools_condition already routes to END when the model makes no tool
    # calls, so no unconditional edge from the LLM node to END is needed.
    builder.add_conditional_edges("llm_with_tools", tools_condition)
    builder.add_edge("tools", "llm_with_tools")
    return builder.compile()
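# Minimal usage sketch (assumes config.py supplies valid Gemini credentials;
# no checkpointer is attached, so no thread config is needed):
#
#   import asyncio
#   from langchain_core.messages import HumanMessage
#
#   async def main():
#       graph = await create_graph(kb_tool=False, mcp_config=None)
#       result = await graph.ainvoke({"messages": [HumanMessage("Hello!")]})
#       print(result["messages"][-1].content)
#
#   asyncio.run(main())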

# Build basic graph (no tools, no memory)
def create_basic_graph():
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        google_api_key=GOOGLE_API_KEY,
        temperature=GEMINI_TEMPERATURE,
    )

    async def llm_basic_node(state: State):
        messages = state["messages"]
        system_prompt = SystemMessage(
            content="""You are a helpful and friendly voice AI assistant. Your responses should be:
- Conversational and natural, as if speaking to a friend
- Concise but informative - aim for 1-3 sentences unless more detail is specifically requested
- Clear and easy to understand when spoken aloud
- Engaging and personable while remaining professional
- Avoid overly complex language or long lists that are hard to follow in audio format

When responding:
- Use a warm, approachable tone
- Speak in a natural rhythm suitable for text-to-speech
- If you need to provide multiple items or steps, break them into digestible chunks
- Ask clarifying questions when needed to better assist the user
- Acknowledge when you don't know something rather than guessing

Remember that users are interacting with you through voice, so structure your responses to be easily understood when heard rather than read.
Don't use abbreviations or numerical content in your responses."""
        )
        # Prepend the system prompt without mutating graph state in place.
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [system_prompt, *messages]
        return {"messages": [await llm.ainvoke(messages)]}

    builder = StateGraph(State)
    builder.add_node("llm_basic", llm_basic_node)
    builder.set_entry_point("llm_basic")
    builder.add_edge("llm_basic", END)
    return builder.compile()  # No checkpointing
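
# Illustrative smoke test for the basic graph (a sketch; assumes valid Gemini
# credentials in config.py and is not part of the deployed app):
if __name__ == "__main__":
    import asyncio
    from langchain_core.messages import HumanMessage

    async def _demo():
        graph = create_basic_graph()
        result = await graph.ainvoke(
            {"messages": [HumanMessage("What can you help me with?")]}
        )
        print(result["messages"][-1].content)

    asyncio.run(_demo())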