Spaces: Sleeping
Commit · cf02b2b
Parent(s): edd4785

Deploy clean Voice Bot backend to HF Spaces

🚀 Features:
- FastAPI application optimized for HF Spaces (port 7860)
- Voice processing with ASR and TTS
- LangChain-powered RAG system for document search
- WebSocket support for real-time communication
- JWT authentication
- Hybrid LLM service (Gemini + Groq)
- Docker configuration with health checks
- Clean project structure without deployment artifacts
✅ Ready for HF Spaces deployment
- .gitignore +62 -0
- Dockerfile +39 -0
- README.md +198 -7
- app.py +145 -0
- audio_services.py +88 -0
- auth.py +25 -0
- config.py +51 -0
- document_service.py +171 -0
- enhanced_websocket_handler.py +395 -0
- hybrid_llm_service.py +261 -0
- lancedb_service.py +436 -0
- llm_service.py +155 -0
- main.py +21 -0
- rag_service.py +322 -0
- requirements.txt +32 -0
- voice_service.py +324 -0
- voice_websocket_server.py +492 -0
- websocket_handler.py +403 -0
.gitignore
ADDED
@@ -0,0 +1,62 @@
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Virtualenv
venv/
ENV/
env/
.venv/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Environment files
.env
.env.local
.env.production
.env.development

# Logs
*.log
logs/

# Database
*.db
*.sqlite
*.sqlite3

# LanceDB data
lancedb_data/

# Temporary files
*.tmp
*.temp
.DS_Store
Thumbs.db
Dockerfile
ADDED
@@ -0,0 +1,39 @@
# Use Python 3.12 as specified
FROM python:3.12-slim

# Install system dependencies
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Create a non-root user
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

# Set working directory
WORKDIR /app

# Copy requirements first for better Docker layer caching
COPY --chown=user ./requirements.txt requirements.txt

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the application code
COPY --chown=user . /app

# Expose the port that HF Spaces requires
EXPOSE 7860

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Run the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
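The HEALTHCHECK above polls `/health` with curl; a minimal Python sketch of the same smoke test, assuming the image was built and started locally with port 7860 published:

```python
import json
import urllib.request

# Hit the same endpoint the Dockerfile HEALTHCHECK curls.
with urllib.request.urlopen("http://localhost:7860/health", timeout=5) as resp:
    payload = json.load(resp)

# app.py's health handler reports status/service/timestamp/version.
assert payload["status"] == "healthy"
print(payload)
```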
README.md
CHANGED
@@ -1,10 +1,201 @@
---
title: PensionBot - Voice Assistant
emoji: 🤖
colorFrom: blue
colorTo: green
sdk: docker
pinned: false
license: mit
app_port: 7860
---

# PensionBot - Voice Assistant 🤖

A sophisticated AI-powered voice assistant designed for government pension queries and document searches. Built with FastAPI, this backend provides comprehensive API endpoints for voice interaction, document processing, and intelligent responses.

## 🚀 Features

- **Voice Processing**: Advanced ASR and TTS capabilities
- **Document Search**: RAG-based government document knowledge base
- **Hybrid AI**: Multiple LLM providers for optimal responses
- **WebSocket Support**: Real-time communication
- **Authentication**: JWT-based secure access
- **Policy Analysis**: Visual charts and scenario analysis

## 📡 API Endpoints

- `GET /` - Service information and available endpoints
- `GET /health` - Health check with service status
- `POST /chat` - Text-based conversation interface
- `WebSocket /ws` - Real-time voice and text communication
- `GET /docs` - Interactive API documentation

## 🛠️ Technology Stack

- **FastAPI**: High-performance web framework
- **LangChain**: AI orchestration and document processing
- **LanceDB**: Vector database for document search
- **Whisper**: Speech-to-text processing
- **Edge-TTS**: Text-to-speech synthesis
- **WebSocket**: Real-time communication

## 📖 Usage

The API is accessible at the base URL of this space. Use the `/docs` endpoint to explore the interactive API documentation, or the WebSocket endpoint for real-time voice interaction.

### Example Usage:

```bash
# Health check
curl https://chabhishek28-pensionbot.hf.space/health

# Chat endpoint
curl -X POST https://chabhishek28-pensionbot.hf.space/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "Tell me about pension policies"}'
```

## 🔐 Environment Variables

The following environment variables are required (set them in your Hugging Face Space secrets):

- `GOOGLE_API_KEY`: Google Gemini API key
- `GROQ_API_KEY`: Groq API key for Whisper
- `TAVILY_API_KEY`: Tavily search API key
- `JWT_SECRET_KEY`: JWT authentication secret

## 🔒 Security

This API includes JWT-based authentication for secure access to protected endpoints.

## 📄 License

MIT License - see LICENSE for details.
app.py
ADDED
@@ -0,0 +1,145 @@
import os
import logging
from datetime import datetime
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from websocket_handler import handle_websocket_connection
from enhanced_websocket_handler import handle_enhanced_websocket_connection
from hybrid_llm_service import HybridLLMService
from voice_service import VoiceService
from rag_service import search_documents
from lancedb_service import LanceDBService
import config
from dotenv import load_dotenv

# MCP and Authentication imports
from fastapi import Depends
from pydantic import BaseModel
from typing import Optional
from auth import get_current_user

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Get configuration
config_dict = {
    "ALLOWED_ORIGINS": config.ALLOWED_ORIGINS,
    "ENABLE_VOICE_FEATURES": config.ENABLE_VOICE_FEATURES
}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler"""
    # Startup
    logger.info("🚀 Starting Voice Bot Application...")
    logger.info("✅ Application started successfully")
    yield
    # Shutdown (if needed)
    logger.info("🛑 Shutting down Voice Bot Application...")

# Create FastAPI application
app = FastAPI(
    title="Voice Bot Government Assistant",
    description="AI-powered voice assistant for government policies and services",
    version="1.0.0",
    lifespan=lifespan
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize services (lazy loading for HF Spaces)
llm_service = None
voice_service = None
lancedb_service = None

def get_llm_service():
    global llm_service
    if llm_service is None:
        llm_service = HybridLLMService()
    return llm_service

def get_voice_service():
    global voice_service
    if voice_service is None:
        voice_service = VoiceService()
    return voice_service

def get_lancedb_service():
    global lancedb_service
    if lancedb_service is None:
        lancedb_service = LanceDBService()
    return lancedb_service

# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "service": "voice-bot-api",
        "timestamp": datetime.now().isoformat(),
        "version": "1.0.0"
    }

# Root endpoint
@app.get("/")
async def root():
    """Root endpoint with service information"""
    return {
        "message": "Voice Bot Government Assistant API",
        "status": "running",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "chat": "/chat",
            "websocket": "/ws",
            "docs": "/docs"
        }
    }

# Chat endpoint
@app.post("/chat")
async def chat_endpoint(request: dict):
    """Text-based chat endpoint"""
    try:
        message = request.get("message", "")
        if not message:
            raise HTTPException(status_code=400, detail="Message is required")

        llm = get_llm_service()
        response = await llm.get_response(message)

        return {
            "response": response,
            "timestamp": datetime.now().isoformat()
        }
    except HTTPException:
        # Re-raise client errors (e.g. the 400 above) instead of masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time communication"""
    await handle_enhanced_websocket_connection(websocket)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
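A minimal client sketch for the `/chat` endpoint above, using the requests library (the Space URL is the one from the README's curl example; swap in your own deployment as needed):

```python
import requests

# The handler reads {"message": ...} from the JSON body and returns
# {"response": ..., "timestamp": ...}.
resp = requests.post(
    "https://chabhishek28-pensionbot.hf.space/chat",
    json={"message": "Tell me about pension policies"},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["response"])
```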
audio_services.py
ADDED
@@ -0,0 +1,88 @@
from groq import AsyncGroq
from config import GROQ_API_KEY, ASR_MODEL, MURF_API_KEY
import soundfile as sf
import numpy as np
from huggingface_hub import hf_hub_download
from concurrent.futures import ThreadPoolExecutor
from murf import AsyncMurf


groq = AsyncGroq(api_key=GROQ_API_KEY)
executor = ThreadPoolExecutor(max_workers=4)

# kokoro_device = "cuda" if torch.cuda.is_available() else "cpu"
# kokoro_model = KModel().to(kokoro_device).eval()
# model_path = hf_hub_download(repo_id='hexgrad/Kokoro-82M', filename="kokoro-v1_0.pth")
# kokoro_model.load_state_dict(torch.load(model_path, map_location=kokoro_device), strict=False)
# kokoro_pipeline = KPipeline(lang_code='a', model=False)
# voice_path = hf_hub_download("hexgrad/Kokoro-82M", "voices/af_heart.pt")
# kokoro_voice = torch.load(voice_path, weights_only=True).to(kokoro_device)

async def groq_asr_bytes(audio_bytes: bytes, model: str = ASR_MODEL, language: str = "en") -> str:
    """Transcribes audio using Groq ASR."""
    # Groq client is already async, so we can use it directly
    resp = await groq.audio.transcriptions.create(
        model=model,
        file=("audio.wav", audio_bytes, "audio/wav"),
        response_format="text",
        language=language
    )
    return resp

murf_client = AsyncMurf(api_key=MURF_API_KEY)

async def murf_tts(text: str, voice_id: str = "en-IN-isha", format: str = "MP3") -> bytes:
    resp = murf_client.text_to_speech.stream(
        text=text,
        voice_id=voice_id,
        format=format,
        sample_rate=44100.0
    )
    chunks = [chunk async for chunk in resp]
    full_audio = b''.join(chunks)
    return full_audio


# def groq_tts(text: str, speed: float = 1.0) -> bytes:
#     try:
#         audio_segments = []
#         for _, ps, _ in kokoro_pipeline(text, kokoro_voice, speed):
#             ref_s = kokoro_voice[len(ps) - 1]
#             audio = kokoro_model(ps, ref_s, speed)
#             audio_np = audio.cpu().numpy().astype(np.float32)
#             audio_segments.append(audio_np)

#         full_audio = np.concatenate(audio_segments)

#         # Write to WAV bytes
#         buf = io.BytesIO()
#         sf.write(buf, full_audio, samplerate=24000, format="WAV", subtype="PCM_16")
#         buf.seek(0)
#         return buf.read()

#     except Exception as e:
#         print("Kokoro TTS synthesis failed")
#         raise RuntimeError(f"Kokoro TTS failed: {e}")


'''def groq_tts(text: str, model: str = TTS_MODEL, voice: str = TTS_VOICE) -> bytes:
    text = text[:1000]
    resp = groq.audio.speech.create(
        model=model,
        voice=voice,
        input=text,
        response_format="wav"
    )
    print(resp.read()[:10])
    return resp.read()
'''
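A short usage sketch for the ASR helper, assuming `GROQ_API_KEY` is configured and `sample.wav` is a hypothetical local recording:

```python
import asyncio

from audio_services import groq_asr_bytes

async def main() -> None:
    # Feed raw WAV bytes to the Groq transcription wrapper.
    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()
    transcript = await groq_asr_bytes(audio_bytes, language="en")
    print(transcript)

asyncio.run(main())
```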
auth.py
ADDED
@@ -0,0 +1,25 @@
import jwt
from fastapi import HTTPException, status, Header
from jwt import PyJWTError
from dotenv import load_dotenv
import os

load_dotenv()

SUPABASE_JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET")

def verify_token(token: str):
    try:
        payload = jwt.decode(token, SUPABASE_JWT_SECRET, algorithms=["HS256"], audience="authenticated")
        return payload
    except PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )

def get_current_user(authorization: str = Header(...)):
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header")
    token = authorization.split(" ")[1]
    payload = verify_token(token)
    return payload
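`verify_token` only accepts HS256 tokens carrying the `authenticated` audience; a sketch of minting and checking one with the same PyJWT library (the secret here is a test stand-in for `SUPABASE_JWT_SECRET`):

```python
import jwt  # PyJWT, the same library auth.py imports

secret = "test-secret"  # stand-in for SUPABASE_JWT_SECRET

# Encode a token with the "authenticated" audience that verify_token enforces.
token = jwt.encode({"sub": "user-123", "aud": "authenticated"}, secret, algorithm="HS256")

# Decoding succeeds only when secret, algorithm, and audience all match.
payload = jwt.decode(token, secret, algorithms=["HS256"], audience="authenticated")
print(payload)  # {'sub': 'user-123', 'aud': 'authenticated'}
```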
config.py
ADDED
@@ -0,0 +1,51 @@
from dotenv import load_dotenv
import os

load_dotenv()

# API Configuration
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")  # Backward compatibility

# LangSmith Configuration (optional)
LANGSMITH_API_KEY = os.environ.get("LANGSMITH_API_KEY")
LANGCHAIN_TRACING_V2 = os.environ.get("LANGCHAIN_TRACING_V2", "false").lower() == "true"
LANGCHAIN_PROJECT = os.environ.get("LANGCHAIN_PROJECT", "voice-bot-government-docs")

# Hybrid LLM Configuration
USE_HYBRID_LLM = os.environ.get("USE_HYBRID_LLM", "false").lower() == "true"
FAST_LLM_PROVIDER = os.environ.get("FAST_LLM_PROVIDER", "groq")
COMPLEX_LLM_PROVIDER = os.environ.get("COMPLEX_LLM_PROVIDER", "gemini")

# Groq Configuration
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile")

# Gemini Model Configuration
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro-latest")
GEMINI_TEMPERATURE = float(os.environ.get("GEMINI_TEMPERATURE", "0.7"))

# Voice Features Configuration
ENABLE_VOICE_FEATURES = os.environ.get("ENABLE_VOICE_FEATURES", "false").lower() == "true"
TTS_PROVIDER = os.environ.get("TTS_PROVIDER", "edge-tts")
ASR_PROVIDER = os.environ.get("ASR_PROVIDER", "whisper")
VOICE_LANGUAGE = os.environ.get("VOICE_LANGUAGE", "en-US")
DEFAULT_VOICE_SPEED = float(os.environ.get("DEFAULT_VOICE_SPEED", "1.0"))

# Embedding Model Configuration
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
EMBEDDING_SIZE = 768

# Text Processing Configuration
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1000"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "200"))

# CORS Configuration
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "*").split(",") if os.environ.get("ALLOWED_ORIGINS") != "*" else ["*"]

# LanceDB Configuration
LANCEDB_PATH = os.environ.get("LANCEDB_PATH", "./lancedb_data")

# JWT Configuration
JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
JWT_ALGORITHM = os.environ.get("JWT_ALGORITHM", "HS256")
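A quick sketch of how the flag and CORS parsing above behaves (values are set inline for illustration; in the app they come from Space secrets or `.env`):

```python
import os

# Boolean flags: only the literal string "true" (case-insensitive) enables a feature.
os.environ["USE_HYBRID_LLM"] = "True"
print(os.environ.get("USE_HYBRID_LLM", "false").lower() == "true")  # True

# Origins: "*" stays a wildcard; anything else is split on commas.
os.environ["ALLOWED_ORIGINS"] = "https://a.example,https://b.example"
origins = (os.environ.get("ALLOWED_ORIGINS", "*").split(",")
           if os.environ.get("ALLOWED_ORIGINS") != "*" else ["*"])
print(origins)  # ['https://a.example', 'https://b.example']
```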
document_service.py
ADDED
@@ -0,0 +1,171 @@
from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
import pdfplumber
import os
import asyncio
from typing import List
from lancedb_service import lancedb_service
from config import CHUNK_SIZE, CHUNK_OVERLAP
from datetime import datetime

def read_pdf(file: UploadFile) -> str:
    with pdfplumber.open(file.file) as pdf:
        text = "\n".join(page.extract_text() or "" for page in pdf.pages)
    return text

async def process_document_upload(file: UploadFile, userid: str, knowledge_base: str):
    try:
        filename = file.filename
        if not filename.lower().endswith(".pdf"):
            return {"error": "Only PDF files are supported"}

        # Read PDF
        with pdfplumber.open(file.file) as pdf:
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)

        # Chunk text
        splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
        chunks = splitter.split_text(text)

        # Batch create Document objects with metadata including knowledge base
        upload_date = datetime.now().isoformat()
        docs = [
            Document(
                page_content=chunk,
                metadata={
                    "source": filename,
                    "userid": userid,
                    "knowledge_base": knowledge_base,
                    "upload_date": upload_date
                }
            )
            for chunk in chunks
        ]

        # ✅ Batch embed & insert using LanceDB
        await lancedb_service.add_documents(docs, userid, knowledge_base, filename)

        return {
            "status": "uploaded",
            "chunks": len(docs),
            "file": filename,
            "knowledge_base": knowledge_base
        }

    except Exception as e:
        return {"error": str(e)}

def read_pdf_from_path(pdf_path: str) -> str:
    """Read PDF content from file path"""
    try:
        with pdfplumber.open(pdf_path) as pdf:
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)
        return text
    except Exception as e:
        print(f"Error reading PDF {pdf_path}: {str(e)}")
        return ""

async def process_documents_from_folder(folder_path: str, userid: str = "system", knowledge_base: str = "government_docs"):
    """Process all PDF documents from the specified folder"""
    try:
        if not os.path.exists(folder_path):
            return {"error": f"Folder path {folder_path} does not exist"}

        pdf_files = [f for f in os.listdir(folder_path) if f.lower().endswith('.pdf')]

        if not pdf_files:
            return {"error": "No PDF files found in the folder"}

        processed_files = []
        total_chunks = 0

        for pdf_file in pdf_files:
            pdf_path = os.path.join(folder_path, pdf_file)

            # Read PDF content
            text = read_pdf_from_path(pdf_path)

            if not text.strip():
                print(f"Skipping {pdf_file} - no text content extracted")
                continue

            # Chunk text
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=CHUNK_SIZE,
                chunk_overlap=CHUNK_OVERLAP
            )
            chunks = splitter.split_text(text)

            # Create Document objects with metadata
            upload_date = datetime.now().isoformat()
            docs = [
                Document(
                    page_content=chunk,
                    metadata={
                        "source": pdf_file,
                        "userid": userid,
                        "knowledge_base": knowledge_base,
                        "upload_date": upload_date,
                        "file_path": pdf_path
                    }
                )
                for chunk in chunks
            ]

            # Add documents to LanceDB
            await lancedb_service.add_documents(docs, userid, knowledge_base, pdf_file)

            processed_files.append({
                "file": pdf_file,
                "chunks": len(chunks)
            })
            total_chunks += len(chunks)

            print(f"Processed {pdf_file}: {len(chunks)} chunks")

        return {
            "status": "success",
            "processed_files": len(processed_files),
            "total_chunks": total_chunks,
            "files": processed_files,
            "knowledge_base": knowledge_base
        }

    except Exception as e:
        return {"error": str(e)}

async def initialize_document_database():
    """Initialize the document database with documents from the aa folder"""
    # Path to the documents folder
    documents_folder = "/Users/abhishekchoudhary/Abhi Project/aa/raw_documents/Documents"

    print("Starting document database initialization...")
    result = await process_documents_from_folder(
        folder_path=documents_folder,
        userid="system",
        knowledge_base="government_docs"
    )

    if "error" in result:
        print(f"Error initializing database: {result['error']}")
    else:
        print(f"Successfully initialized database with {result['total_chunks']} chunks from {result['processed_files']} files")

    return result

async def get_available_knowledge_bases() -> List[str]:
    """Get list of available knowledge bases"""
    try:
        return await lancedb_service.get_knowledge_bases()
    except Exception as e:
        print(f"Error getting knowledge bases: {str(e)}")
        return []

async def get_documents_by_knowledge_base(knowledge_base: str) -> List[dict]:
    """Get list of documents in a specific knowledge base"""
    try:
        return await lancedb_service.get_documents_by_knowledge_base(knowledge_base)
    except Exception as e:
        print(f"Error getting documents for knowledge base {knowledge_base}: {str(e)}")
        return []
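A usage sketch for the folder-ingestion path above (the folder name is a hypothetical local path; each PDF is chunked and written to LanceDB under the given knowledge base):

```python
import asyncio

from document_service import process_documents_from_folder

# Chunk and embed every PDF in the folder into the "government_docs" knowledge base.
result = asyncio.run(process_documents_from_folder(
    folder_path="./raw_documents/Documents",  # hypothetical local folder of PDFs
    userid="system",
    knowledge_base="government_docs",
))
print(result)  # {"status": "success", "processed_files": ..., "total_chunks": ...}
```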
enhanced_websocket_handler.py
ADDED
@@ -0,0 +1,395 @@
"""
Enhanced WebSocket handler with hybrid LLM and optional voice features
"""

from fastapi import WebSocket, WebSocketDisconnect
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
import logging
import json
import asyncio
import uuid
import tempfile
import base64
from pathlib import Path

from llm_service import create_graph, create_basic_graph
from lancedb_service import lancedb_service
from hybrid_llm_service import HybridLLMService
from voice_service import voice_service
from rag_service import search_government_docs

# Initialize hybrid LLM service
hybrid_llm_service = HybridLLMService()

logger = logging.getLogger("voicebot")

async def handle_enhanced_websocket_connection(websocket: WebSocket):
    """Enhanced WebSocket handler with hybrid LLM and voice features"""
    await websocket.accept()
    logger.info("🔗 Enhanced WebSocket client connected.")

    # Initialize session data
    session_data = {
        "messages": [],
        "user_preferences": {
            "voice_enabled": False,
            "preferred_voice": "en-US-AriaNeural",
            "response_mode": "text"  # text, voice, both
        },
        "context": ""
    }

    try:
        # Get initial connection data
        initial_data = await websocket.receive_json()

        # Extract user preferences
        if "preferences" in initial_data:
            session_data["user_preferences"].update(initial_data["preferences"])

        # Setup user session
        flag = "user_id" in initial_data
        graph = None  # Initialize graph variable

        if flag:
            thread_id = initial_data.get("user_id")
            knowledge_base = initial_data.get("knowledge_base", "government_docs")

            # Use hybrid LLM or traditional graph based on configuration
            if hybrid_llm_service.use_hybrid:
                logger.info("🤖 Using Hybrid LLM Service")
                use_hybrid = True
            else:
                graph = await create_graph(kb_tool=True, mcp_config=None)
                use_hybrid = False

            config = {
                "configurable": {
                    "thread_id": thread_id,
                    "knowledge_base": knowledge_base,
                }
            }
        else:
            # Basic setup for unauthenticated users
            thread_id = str(uuid.uuid4())
            knowledge_base = "government_docs"
            use_hybrid = hybrid_llm_service.use_hybrid

            if not use_hybrid:
                graph = create_basic_graph()

            config = {"configurable": {"thread_id": thread_id}}

        # Send initial greeting with voice/hybrid capabilities
        await send_enhanced_greeting(websocket, session_data)

        # Main message handling loop
        while True:
            try:
                data = await websocket.receive_json()

                if data["type"] == "text_message":
                    await handle_text_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )

                elif data["type"] == "voice_message":
                    await handle_voice_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )

                elif data["type"] == "preferences_update":
                    await handle_preferences_update(websocket, data, session_data)

                elif data["type"] == "get_voice_status":
                    await websocket.send_json({
                        "type": "voice_status",
                        "data": voice_service.get_voice_status()
                    })

                elif data["type"] == "get_llm_status":
                    await websocket.send_json({
                        "type": "llm_status",
                        "data": hybrid_llm_service.get_provider_info()
                    })

            except WebSocketDisconnect:
                logger.info("🔌 WebSocket client disconnected.")
                break
            except Exception as e:
                logger.error(f"❌ Error handling message: {e}")
                await websocket.send_json({
                    "type": "error",
                    "message": f"An error occurred: {str(e)}"
                })

    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected during setup.")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}"
            })
        except:
            pass

async def send_enhanced_greeting(websocket: WebSocket, session_data: dict):
    """Send enhanced greeting with system capabilities"""

    # Get system status
    llm_info = hybrid_llm_service.get_provider_info()
    voice_status = voice_service.get_voice_status()

    greeting_text = f"""🤖 Welcome to the Government Document Assistant!

I'm powered by a hybrid AI system that can help you with:
• Government policies and procedures
• Document search and analysis
• Scenario analysis with visualizations
• Quick answers and detailed explanations

Current capabilities:
• LLM: {'Hybrid (' + llm_info['fast_provider'] + '/' + llm_info['complex_provider'] + ')' if llm_info['hybrid_enabled'] else 'Single provider'}
• Voice features: {'Enabled' if voice_status['voice_enabled'] else 'Disabled'}

How can I assist you today? You can ask me about any government policies, procedures, or documents!"""

    # Send text greeting
    await websocket.send_json({
        "type": "message_response",
        "message": greeting_text,
        "provider_used": "system",
        "capabilities": {
            "hybrid_llm": llm_info['hybrid_enabled'],
            "voice_features": voice_status['voice_enabled'],
            "scenario_analysis": True
        }
    })

    # Send voice greeting if enabled
    if session_data["user_preferences"]["voice_enabled"] and voice_status['voice_enabled']:
        voice_greeting = "Welcome to the Government Document Assistant! I can help you with policies, procedures, and document analysis. How can I assist you today?"
        audio_data = await voice_service.text_to_speech(voice_greeting)

        if audio_data:
            await websocket.send_json({
                "type": "audio_response",
                "audio_data": base64.b64encode(audio_data).decode(),
                "format": "mp3"
            })

async def handle_text_message(websocket: WebSocket, data: dict, session_data: dict,
                              use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle text message with hybrid LLM"""

    user_message = data["message"]
    logger.info(f"💬 Received text message: {user_message}")

    # Send acknowledgment
    await websocket.send_json({
        "type": "message_received",
        "message": "Processing your message..."
    })

    try:
        if use_hybrid:
            # Use hybrid LLM service
            response_text, provider_used = await get_hybrid_response(
                user_message, session_data["context"], config, knowledge_base
            )
        else:
            # Use traditional graph approach
            session_data["messages"].append(HumanMessage(content=user_message))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Handle scenario analysis images
        if "SCENARIO_ANALYSIS_IMAGE:" in response_text:
            await handle_scenario_response(websocket, response_text, provider_used)
        else:
            await send_text_response(websocket, response_text, provider_used, session_data)

    except Exception as e:
        logger.error(f"❌ Error processing text message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing your message: {str(e)}"
        })

async def handle_voice_message(websocket: WebSocket, data: dict, session_data: dict,
                               use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle voice message with ASR and TTS"""

    if not voice_service.is_voice_enabled():
        await websocket.send_json({
            "type": "error",
            "message": "Voice features are not enabled"
        })
        return

    try:
        # Decode audio data
        audio_data = base64.b64decode(data["audio_data"])

        # Save to temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        # Convert speech to text
        transcribed_text = await voice_service.speech_to_text(temp_file_path)

        # Clean up temp file
        Path(temp_file_path).unlink()

        if not transcribed_text:
            await websocket.send_json({
                "type": "error",
                "message": "Could not transcribe audio"
            })
            return

        logger.info(f"🎤 Transcribed: {transcribed_text}")

        # Send transcription
        await websocket.send_json({
            "type": "transcription",
            "text": transcribed_text
        })

        # Process as text message
        if use_hybrid:
            response_text, provider_used = await get_hybrid_response(
                transcribed_text, session_data["context"], config, knowledge_base
            )
        else:
            session_data["messages"].append(HumanMessage(content=transcribed_text))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Send text response
        await send_text_response(websocket, response_text, provider_used, session_data)

        # Send voice response if enabled
        if session_data["user_preferences"]["response_mode"] in ["voice", "both"]:
            voice_text = voice_service.create_voice_response_with_guidance(
                response_text,
                suggested_resources=["Government portal", "Local offices"],
                redirect_info="contact your local government office for personalized assistance"
            )

            audio_response = await voice_service.text_to_speech(
                voice_text,
                session_data["user_preferences"]["preferred_voice"]
            )

            if audio_response:
                await websocket.send_json({
                    "type": "audio_response",
                    "audio_data": base64.b64encode(audio_response).decode(),
                    "format": "mp3"
                })

    except Exception as e:
        logger.error(f"❌ Error processing voice message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing voice message: {str(e)}"
        })

async def get_hybrid_response(user_message: str, context: str, config: dict, knowledge_base: str):
    """Get response using hybrid LLM with document search"""

    # Search for relevant documents
    try:
        search_results = await search_government_docs.ainvoke(
            {"query": user_message},
            config=config
        )
        context = search_results if search_results else context
    except:
        logger.warning("Document search failed, using existing context")

    # Get hybrid LLM response
    response_text = await hybrid_llm_service.get_response(
        user_message,
        context=context,
        system_prompt="""You are a helpful government document assistant. Provide accurate, helpful responses based on the context provided. When appropriate, suggest additional resources or redirect users to relevant departments for more assistance."""
    )

    # Determine which provider was used
    complexity = hybrid_llm_service.determine_task_complexity(user_message, context)
    provider_used = hybrid_llm_service.choose_llm_provider(complexity)

    return response_text, provider_used

async def send_text_response(websocket: WebSocket, response_text: str, provider_used: str, session_data: dict):
    """Send text response to client"""

    await websocket.send_json({
        "type": "message_response",
        "message": response_text,
        "provider_used": provider_used,
        "timestamp": asyncio.get_event_loop().time()
    })

    # Update session context
    session_data["context"] = response_text[-1000:]  # Keep last 1000 chars as context

async def handle_scenario_response(websocket: WebSocket, response_text: str, provider_used: str):
    """Handle scenario analysis response with images"""

    parts = response_text.split("SCENARIO_ANALYSIS_IMAGE:")
    text_part = parts[0].strip()

    # Send text part
    if text_part:
        await websocket.send_json({
            "type": "message_response",
            "message": text_part,
            "provider_used": provider_used
        })

    # Send image parts
    for i, part in enumerate(parts[1:], 1):
        try:
            image_data = part.strip()
            await websocket.send_json({
                "type": "scenario_image",
                "image_data": image_data,
                "image_index": i,
                "chart_type": "analysis"
            })
        except Exception as e:
            logger.error(f"Error sending scenario image {i}: {e}")

async def handle_preferences_update(websocket: WebSocket, data: dict, session_data: dict):
    """Handle user preferences update"""

    try:
        session_data["user_preferences"].update(data["preferences"])

        await websocket.send_json({
            "type": "preferences_updated",
            "preferences": session_data["user_preferences"]
        })

        logger.info(f"🔧 Updated user preferences: {session_data['user_preferences']}")

    except Exception as e:
        logger.error(f"❌ Error updating preferences: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error updating preferences: {str(e)}"
        })

# Keep the original function for backward compatibility
async def handle_websocket_connection(websocket: WebSocket):
    """Original websocket handler for backward compatibility"""
    await handle_enhanced_websocket_connection(websocket)
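A minimal client sketch for the JSON protocol this handler expects, using the third-party websockets package (the URI is a placeholder; note the handler reads one initial JSON frame before its typed-message loop starts):

```python
import asyncio
import json

import websockets  # pip install websockets

async def main() -> None:
    uri = "wss://chabhishek28-pensionbot.hf.space/ws"  # placeholder deployment URI
    async with websockets.connect(uri) as ws:
        # Initial frame: optional user_id / preferences, consumed once during setup.
        await ws.send(json.dumps({"preferences": {"response_mode": "text"}}))
        print(json.loads(await ws.recv()))  # greeting ("message_response")

        # Typed message, as dispatched in the main loop.
        await ws.send(json.dumps({"type": "text_message",
                                  "message": "What pension schemes are available?"}))
        print(json.loads(await ws.recv()))  # "message_received" acknowledgment
        print(json.loads(await ws.recv()))  # "message_response" with the answer

asyncio.run(main())
```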
hybrid_llm_service.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Hybrid LLM Service that intelligently routes between the Groq and Gemini APIs
based on task complexity and user requirements.
"""

import os
import asyncio
from enum import Enum
from typing import Dict, Any, Optional
import logging
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage

logger = logging.getLogger(__name__)

class TaskComplexity(Enum):
    SIMPLE = "simple"
    COMPLEX = "complex"

class LLMProvider(Enum):
    GROQ = "groq"
    GEMINI = "gemini"

class HybridLLMService:
    def __init__(self):
        # Initialize Groq (primary provider)
        self.groq_api_key = os.getenv("GROQ_API_KEY")
        self.groq_model = os.getenv("GROQ_MODEL", "llama-3.1-70b-versatile")

        if self.groq_api_key:
            self.groq_llm = ChatGroq(
                groq_api_key=self.groq_api_key,
                model_name=self.groq_model,
                temperature=0.7
            )
            logger.info(f"✅ Groq LLM initialized: {self.groq_model}")
        else:
            self.groq_llm = None
            logger.warning("⚠️ Groq API key not found")

        # Initialize Gemini (secondary/fallback provider)
        self.google_api_key = os.getenv("GOOGLE_API_KEY")
        self.gemini_model = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")  # Use the flash model for the free tier

        if self.google_api_key:
            try:
                self.gemini_llm = ChatGoogleGenerativeAI(
                    model=self.gemini_model,
                    google_api_key=self.google_api_key,
                    temperature=0.7
                )
                logger.info(f"✅ Gemini LLM initialized: {self.gemini_model}")
            except Exception as e:
                self.gemini_llm = None
                logger.warning(f"⚠️ Gemini initialization failed: {e}")
        else:
            self.gemini_llm = None
            logger.warning("⚠️ Google API key not found")

        # Hybrid configuration
        self.use_hybrid = os.getenv("USE_HYBRID_LLM", "true").lower() == "true"
        self.primary_provider = LLMProvider.GROQ  # Always use Groq as primary

        logger.info(f"🤖 Hybrid LLM Service initialized (Primary: {self.primary_provider.value})")

    def analyze_task_complexity(self, message: str) -> TaskComplexity:
        """Analyze whether a task requires complex reasoning or a simple response"""
        complex_keywords = [
            'analyze', 'compare', 'evaluate', 'scenario', 'chart', 'graph',
            'visualization', 'complex', 'detailed analysis', 'multi-step',
            'comprehensive', 'in-depth', 'elaborate', 'breakdown'
        ]

        simple_keywords = [
            'what is', 'who is', 'when', 'where', 'how to', 'define',
            'explain', 'tell me', 'show me', 'list', 'summary'
        ]

        message_lower = message.lower()

        # Count complex vs. simple indicators
        complex_score = sum(1 for keyword in complex_keywords if keyword in message_lower)
        simple_score = sum(1 for keyword in simple_keywords if keyword in message_lower)

        # A very long message (>200 chars) or a majority of complex keywords implies a complex task
        if len(message) > 200 or complex_score > simple_score:
            return TaskComplexity.COMPLEX

        return TaskComplexity.SIMPLE

    def choose_llm_provider(self, message: str) -> LLMProvider:
        """Choose the best LLM provider based on task complexity and availability"""

        # If hybrid mode is disabled, always use the primary (Groq)
        if not self.use_hybrid:
            return LLMProvider.GROQ if self.groq_llm else LLMProvider.GEMINI

        # Always prefer Groq for better speed and reliability
        if self.groq_llm:
            return LLMProvider.GROQ

        # Fall back to Gemini only if Groq is not available
        if self.gemini_llm:
            return LLMProvider.GEMINI

        # If neither is available, return Groq (the error is handled gracefully downstream)
        return LLMProvider.GROQ

    async def get_response(self, message: str, context: str = "") -> str:
        """Get a response from the chosen LLM provider"""
        provider = self.choose_llm_provider(message)
        complexity = self.analyze_task_complexity(message)

        logger.info(f"🎯 Using {provider.value} for {complexity.value} task")

        try:
            if provider == LLMProvider.GROQ and self.groq_llm:
                return await self._get_groq_response(message, context)
            elif provider == LLMProvider.GEMINI and self.gemini_llm:
                return await self._get_gemini_response(message, context)
            else:
                # Fallback logic
                if self.groq_llm:
                    logger.info("🔄 Falling back to Groq")
                    return await self._get_groq_response(message, context)
                elif self.gemini_llm:
                    logger.info("🔄 Falling back to Gemini")
                    return await self._get_gemini_response(message, context)
                else:
                    return "I apologize, but no AI providers are currently available. Please check your API keys."

        except Exception as e:
            logger.error(f"❌ Error with {provider.value}: {e}")

            # Try the fallback provider
            if provider == LLMProvider.GROQ and self.gemini_llm:
                logger.info("🔄 Groq failed, trying Gemini")
                try:
                    return await self._get_gemini_response(message, context)
                except Exception as gemini_error:
                    logger.error(f"❌ Gemini fallback also failed: {gemini_error}")
                    return "I apologize, but I'm experiencing technical difficulties. Both AI providers are currently unavailable."

            elif provider == LLMProvider.GEMINI and self.groq_llm:
                logger.info("🔄 Gemini failed, trying Groq")
                try:
                    return await self._get_groq_response(message, context)
                except Exception as groq_error:
                    logger.error(f"❌ Groq fallback also failed: {groq_error}")
                    return "I apologize, but I'm experiencing technical difficulties. Both AI providers are currently unavailable."

            return f"I apologize, but I encountered an error: {str(e)}"

    async def _get_groq_response(self, message: str, context: str = "") -> str:
        """Get a response from the Groq LLM"""
        system_prompt = """You are a helpful AI assistant specializing in government policies and procedures.
You have access to government documents and can provide accurate information based on them.
Provide clear, concise, and helpful responses."""

        if context:
            system_prompt += f"\n\nRelevant context from documents:\n{context}"

        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=message)
        ]

        response = await self.groq_llm.ainvoke(messages)
        return response.content

    async def _get_gemini_response(self, message: str, context: str = "") -> str:
        """Get a response from the Gemini LLM"""
        system_prompt = """You are a helpful AI assistant specializing in government policies and procedures.
You have access to government documents and can provide accurate information based on them.
Provide detailed, analytical responses when needed."""

        if context:
            system_prompt += f"\n\nRelevant context from documents:\n{context}"

        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=message)
        ]

        response = await self.gemini_llm.ainvoke(messages)
        return response.content

    async def get_streaming_response(self, message: str, context: str = ""):
        """Get a streaming response from the chosen LLM provider"""
        provider = self.choose_llm_provider(message)

        try:
            if provider == LLMProvider.GROQ and self.groq_llm:
                async for chunk in self._get_groq_streaming_response(message, context):
                    yield chunk
            elif provider == LLMProvider.GEMINI and self.gemini_llm:
                async for chunk in self._get_gemini_streaming_response(message, context):
                    yield chunk
            else:
                # Fall back to whichever provider is available
                if self.groq_llm:
                    async for chunk in self._get_groq_streaming_response(message, context):
                        yield chunk
                else:
                    yield "No AI providers are currently available."

        except Exception as e:
            logger.error(f"❌ Streaming error with {provider.value}: {e}")

            # Try the fallback provider
            if provider == LLMProvider.GROQ and self.gemini_llm:
                try:
                    async for chunk in self._get_gemini_streaming_response(message, context):
                        yield chunk
                except Exception:
                    yield "I apologize, but I'm experiencing technical difficulties."
            elif provider == LLMProvider.GEMINI and self.groq_llm:
                try:
                    async for chunk in self._get_groq_streaming_response(message, context):
                        yield chunk
                except Exception:
                    yield "I apologize, but I'm experiencing technical difficulties."
            else:
                yield f"Error: {str(e)}"

    async def _get_groq_streaming_response(self, message: str, context: str = ""):
        """Get a streaming response from Groq"""
        system_prompt = """You are a helpful AI assistant specializing in government policies and procedures."""

        if context:
            system_prompt += f"\n\nRelevant context:\n{context}"

        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=message)
        ]

        # Groq streaming
        async for chunk in self.groq_llm.astream(messages):
            if chunk.content:
                yield chunk.content
                await asyncio.sleep(0.01)

    async def _get_gemini_streaming_response(self, message: str, context: str = ""):
        """Get a streaming response from Gemini"""
        system_prompt = """You are a helpful AI assistant specializing in government policies and procedures."""

        if context:
            system_prompt += f"\n\nRelevant context:\n{context}"

        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=message)
        ]

        # Gemini streaming
        async for chunk in self.gemini_llm.astream(messages):
            if chunk.content:
                yield chunk.content
                await asyncio.sleep(0.01)
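A minimal usage sketch for the service above, assuming the module is importable as hybrid_llm_service and at least one of GROQ_API_KEY or GOOGLE_API_KEY is set in the environment:

import asyncio
from hybrid_llm_service import HybridLLMService

async def demo():
    service = HybridLLMService()
    # A short factual query is classified SIMPLE and routed to Groq when available.
    answer = await service.get_response("What is a gazette notification?")
    print(answer)
    # The streaming variant yields chunks as they are generated.
    async for chunk in service.get_streaming_response("Compare two pension policy scenarios in depth"):
        print(chunk, end="", flush=True)

asyncio.run(demo())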
lancedb_service.py
ADDED
@@ -0,0 +1,436 @@
import lancedb
import pandas as pd
from langchain_huggingface import HuggingFaceEmbeddings
from config import EMBEDDING_MODEL_NAME, LANCEDB_PATH
from typing import List, Dict, Any, Optional
import logging
import os
import uuid
from datetime import datetime
import json

logger = logging.getLogger("voicebot")

# Lazy-load the embedding model to reduce startup time and memory usage
embedding_model = None

def get_embedding_model():
    """Lazy-load the embedding model"""
    global embedding_model
    if embedding_model is None:
        logger.info(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
        embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={
                "device": "cpu",
                "trust_remote_code": True
            },
            encode_kwargs={
                "normalize_embeddings": True
            }
        )
    return embedding_model

class LanceDBService:
    def __init__(self):
        self.db_path = LANCEDB_PATH
        self.db = None
        self.embedding_model = get_embedding_model()
        self._initialize_db()

    def _initialize_db(self):
        """Initialize the LanceDB connection and create tables if they don't exist"""
        try:
            os.makedirs(self.db_path, exist_ok=True)
            self.db = lancedb.connect(self.db_path)

            # Initialize tables
            self._init_documents_table()
            self._init_personas_table()
            self._init_mcp_servers_table()
            self._init_sessions_table()

            logger.info("✅ LanceDB initialized successfully")
        except Exception as e:
            logger.error(f"❌ Error initializing LanceDB: {e}")
            raise

    def _init_documents_table(self):
        """Initialize the documents table for vector storage"""
        try:
            if "documents" not in self.db.table_names():
                # Create a sample row to define the schema; use a fixed id
                # so the delete below actually matches it.
                sample_data = pd.DataFrame({
                    "id": ["sample"],
                    "content": ["sample"],
                    "metadata": [json.dumps({})],
                    "user_id": ["sample"],
                    "knowledge_base": ["sample"],
                    "filename": ["sample"],
                    "upload_date": [datetime.now().isoformat()],
                    "vector": [get_embedding_model().embed_query("sample")]
                })
                self.db.create_table("documents", sample_data)
                # Delete the schema-defining sample row
                tbl = self.db.open_table("documents")
                tbl.delete("id = 'sample'")
        except Exception as e:
            logger.error(f"❌ Error initializing documents table: {e}")

    def _init_personas_table(self):
        """Initialize the personas table"""
        try:
            if "personas" not in self.db.table_names():
                sample_data = pd.DataFrame({
                    "id": ["sample"],
                    "user_id": ["sample"],
                    "name": ["sample"],
                    "description": ["sample"],
                    "icon": ["sample"],
                    "custom_prompt": ["sample"],
                    "knowledge_base": ["none"],
                    "language": ["en"],
                    "created_at": [datetime.now().isoformat()],
                    "updated_at": [datetime.now().isoformat()]
                })
                self.db.create_table("personas", sample_data)
                tbl = self.db.open_table("personas")
                tbl.delete("id = 'sample'")
        except Exception as e:
            logger.error(f"❌ Error initializing personas table: {e}")

    def _init_mcp_servers_table(self):
        """Initialize the MCP servers table"""
        try:
            if "mcp_servers" not in self.db.table_names():
                sample_data = pd.DataFrame({
                    "id": ["sample"],
                    "user_id": ["sample"],
                    "name": ["sample"],
                    "url": ["sample"],
                    "bearer_token": ["sample"],
                    "created_at": [datetime.now().isoformat()]
                })
                self.db.create_table("mcp_servers", sample_data)
                tbl = self.db.open_table("mcp_servers")
                tbl.delete("id = 'sample'")
        except Exception as e:
            logger.error(f"❌ Error initializing mcp_servers table: {e}")

    def _init_sessions_table(self):
        """Initialize the sessions table"""
        try:
            if "sessions" not in self.db.table_names():
                sample_data = pd.DataFrame({
                    "id": ["sample"],
                    "user_id": ["sample"],
                    "persona_id": ["sample"],
                    "persona_source": ["sample"],
                    "session_summary": ["sample"],
                    "created_at": [datetime.now().isoformat()],
                    "updated_at": [datetime.now().isoformat()]
                })
                self.db.create_table("sessions", sample_data)
                tbl = self.db.open_table("sessions")
                tbl.delete("id = 'sample'")
        except Exception as e:
            logger.error(f"❌ Error initializing sessions table: {e}")

    async def add_documents(self, docs, user_id: str, knowledge_base: str, filename: str):
        """Add documents to the LanceDB vector store"""
        try:
            documents_to_insert = []
            for doc in docs:
                embedding = self.embedding_model.embed_query(doc.page_content)

                doc_data = {
                    "id": str(uuid.uuid4()),
                    "content": doc.page_content,
                    "metadata": json.dumps(doc.metadata),
                    "user_id": user_id,
                    "knowledge_base": knowledge_base,
                    "filename": filename,
                    "upload_date": datetime.now().isoformat(),
                    "vector": embedding
                }
                documents_to_insert.append(doc_data)

            # Insert the documents
            tbl = self.db.open_table("documents")
            df = pd.DataFrame(documents_to_insert)
            tbl.add(df)

            logger.info(f"✅ Added {len(docs)} documents to LanceDB")
            return len(docs)

        except Exception as e:
            logger.error(f"❌ Error adding documents to LanceDB: {e}")
            raise e

    async def similarity_search(self, query: str, user_id: str, knowledge_base: str, k: int = 4):
        """Search for similar documents"""
        try:
            query_embedding = self.embedding_model.embed_query(query)

            tbl = self.db.open_table("documents")

            # Search with user and knowledge-base filters
            results = tbl.search(query_embedding)\
                .where(f"user_id = '{user_id}' AND knowledge_base = '{knowledge_base}'")\
                .limit(k)\
                .to_list()

            docs = []
            for result in results:
                docs.append(type('Document', (), {
                    'page_content': result['content'],
                    'metadata': json.loads(result['metadata']) if result['metadata'] else {}
                })())

            return docs

        except Exception as e:
            logger.error(f"❌ Error searching LanceDB: {e}")
            return []

    async def get_user_knowledge_bases(self, user_id: str) -> List[str]:
        """Get all knowledge bases for a user"""
        try:
            tbl = self.db.open_table("documents")
            df = tbl.search().where(f"user_id = '{user_id}'").to_pandas()

            if df.empty:
                return []

            knowledge_bases = df['knowledge_base'].unique().tolist()
            return [kb for kb in knowledge_bases if kb and kb != "none"]

        except Exception as e:
            logger.error(f"❌ Error fetching knowledge bases: {e}")
            return []

    async def get_kb_documents(self, user_id: str, kb_name: str):
        """Get all documents in a knowledge base"""
        try:
            tbl = self.db.open_table("documents")
            df = tbl.search().where(f"user_id = '{user_id}' AND knowledge_base = '{kb_name}'").to_pandas()

            documents = []
            for _, row in df.iterrows():
                documents.append({
                    "id": row['id'],
                    "filename": row['filename'],
                    "knowledge_base": row['knowledge_base'],
                    "upload_date": row['upload_date']
                })

            return documents

        except Exception as e:
            logger.error(f"❌ Error fetching documents: {e}")
            return []

    async def delete_document_from_kb(self, user_id: str, kb_name: str, filename: str):
        """Delete a document from a knowledge base"""
        try:
            tbl = self.db.open_table("documents")
            tbl.delete(f"user_id = '{user_id}' AND knowledge_base = '{kb_name}' AND filename = '{filename}'")
            return True

        except Exception as e:
            logger.error(f"❌ Error deleting document: {e}")
            return False

    # Persona management methods
    async def insert_persona(self, name: str, description: str, icon: str, custom_prompt: str, user_id: str):
        """Insert a new persona"""
        try:
            persona_data = {
                "id": str(uuid.uuid4()),
                "user_id": user_id,
                "name": name,
                "description": description,
                "icon": icon,
                "custom_prompt": custom_prompt,
                "knowledge_base": "none",
                "language": "en",
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat()
            }

            tbl = self.db.open_table("personas")
            df = pd.DataFrame([persona_data])
            tbl.add(df)

            return persona_data

        except Exception as e:
            logger.error(f"❌ Error inserting persona: {e}")
            raise e

    async def get_user_personas(self, user_id: str):
        """Get all personas for a user"""
        try:
            tbl = self.db.open_table("personas")
            df = tbl.search().where(f"user_id = '{user_id}'").to_pandas()
            return df.to_dict('records')

        except Exception as e:
            logger.error(f"❌ Error fetching personas: {e}")
            return []

    # MCP server methods
    async def create_mcp_server(self, user_id: str, name: str, url: str, bearer_token: str = None):
        """Create an MCP server entry"""
        try:
            server_data = {
                "id": str(uuid.uuid4()),
                "user_id": user_id,
                "name": name,
                "url": url,
                "bearer_token": bearer_token,
                "created_at": datetime.now().isoformat()
            }

            tbl = self.db.open_table("mcp_servers")
            df = pd.DataFrame([server_data])
            tbl.add(df)

            return server_data

        except Exception as e:
            logger.error(f"❌ Error creating MCP server: {e}")
            raise e

    async def get_mcp_servers_for_user(self, user_id: str):
        """Get the MCP servers for a user"""
        try:
            tbl = self.db.open_table("mcp_servers")
            df = tbl.search().where(f"user_id = '{user_id}'").to_pandas()
            return df.to_dict('records')

        except Exception as e:
            logger.error(f"❌ Error fetching MCP servers: {e}")
            return []

    async def delete_mcp_server(self, user_id: str, server_id: str):
        """Delete an MCP server"""
        try:
            tbl = self.db.open_table("mcp_servers")
            tbl.delete(f"user_id = '{user_id}' AND id = '{server_id}'")
            return True

        except Exception as e:
            logger.error(f"❌ Error deleting MCP server: {e}")
            return False

    # Session management
    async def upsert_session_summary(self, user_id: str, persona_id: str, persona_source: str, summary: str):
        """Create or update a session summary (note: currently always inserts a new row)"""
        try:
            session_data = {
                "id": str(uuid.uuid4()),
                "user_id": user_id,
                "persona_id": persona_id,
                "persona_source": persona_source,
                "session_summary": summary,
                "created_at": datetime.now().isoformat(),
                "updated_at": datetime.now().isoformat()
            }

            tbl = self.db.open_table("sessions")
            df = pd.DataFrame([session_data])
            tbl.add(df)

            return session_data

        except Exception as e:
            logger.error(f"❌ Error upserting session: {e}")
            return None

    async def get_knowledge_bases(self) -> List[str]:
        """Get all unique knowledge bases"""
        try:
            tbl = self.db.open_table("documents")
            df = tbl.search().to_pandas()

            if df.empty:
                return []

            knowledge_bases = df['knowledge_base'].unique().tolist()
            return [kb for kb in knowledge_bases if kb and kb != "none"]

        except Exception as e:
            logger.error(f"❌ Error getting knowledge bases: {e}")
            return []

    async def get_documents_by_knowledge_base(self, knowledge_base: str) -> List[dict]:
        """Get the list of documents in a specific knowledge base"""
        try:
            tbl = self.db.open_table("documents")
            df = tbl.search().where(f"knowledge_base = '{knowledge_base}'").to_pandas()

            if df.empty:
                return []

            # Group chunks by filename and collect document info
            documents = []
            for filename in df['filename'].unique():
                file_docs = df[df['filename'] == filename]
                documents.append({
                    "filename": filename,
                    "knowledge_base": knowledge_base,
                    "chunks": len(file_docs),
                    "upload_date": file_docs['upload_date'].iloc[0] if not file_docs.empty else None
                })

            return documents

        except Exception as e:
            logger.error(f"❌ Error getting documents by knowledge base: {e}")
            return []

    async def delete_document(self, filename: str, knowledge_base: str, user_id: str = None):
        """Delete a document from the knowledge base"""
        try:
            tbl = self.db.open_table("documents")

            where_clause = f"filename = '{filename}' AND knowledge_base = '{knowledge_base}'"
            if user_id:
                where_clause += f" AND user_id = '{user_id}'"

            # Delete the document's chunks
            tbl.delete(where_clause)

            logger.info(f"✅ Deleted document {filename} from knowledge base {knowledge_base}")
            return True

        except Exception as e:
            logger.error(f"❌ Error deleting document: {e}")
            return False

    async def search_all_knowledge_bases(self, query: str, k: int = 4):
        """Search across all knowledge bases"""
        try:
            query_embedding = self.embedding_model.embed_query(query)

            tbl = self.db.open_table("documents")

            # Search without user filters for a system-wide search
            results = tbl.search(query_embedding).limit(k).to_list()

            docs = []
            for result in results:
                docs.append(type('Document', (), {
                    'page_content': result['content'],
                    'metadata': json.loads(result['metadata']) if result['metadata'] else {}
                })())

            return docs

        except Exception as e:
            logger.error(f"❌ Error searching all knowledge bases: {e}")
            return []

# Global instance
lancedb_service = LanceDBService()
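A minimal sketch of indexing and querying through the global lancedb_service instance, assuming the embedding model can be loaded locally; the SimpleNamespace stand-in mirrors the page_content/metadata shape that add_documents expects:

import asyncio
from types import SimpleNamespace
from lancedb_service import lancedb_service

async def demo():
    # Stand-in for a LangChain Document chunk.
    doc = SimpleNamespace(page_content="Pension rules summary...", metadata={"source": "demo.pdf"})
    await lancedb_service.add_documents([doc], user_id="u1", knowledge_base="demo_kb", filename="demo.pdf")
    hits = await lancedb_service.similarity_search("pension rules", "u1", "demo_kb", k=2)
    for hit in hits:
        print(hit.page_content[:80])

asyncio.run(demo())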
llm_service.py
ADDED
@@ -0,0 +1,155 @@
from langgraph.graph import StateGraph, END, START
from langchain_core.messages import SystemMessage
from config import GOOGLE_API_KEY, GEMINI_MODEL, GEMINI_TEMPERATURE
from rag_service import search_docs, search_government_docs, analyze_scenario
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from typing import Annotated, List, Optional, Literal
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field, field_validator
from langchain_core.tools import tool
import asyncio

# Optional MCP client - install with: pip install langchain-mcp-adapters
try:
    from langchain_mcp_adapters.client import MultiServerMCPClient
    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False

# Optional Tavily search - requires the TAVILY_API_KEY environment variable
try:
    tavily_search = TavilySearch(max_results=4)
    TAVILY_AVAILABLE = True
except Exception:
    TAVILY_AVAILABLE = False


@tool
def search_tool(query: str):
    """
    Perform an advanced web search using the Tavily Search API with hardcoded options.

    Parameters:
    ----------
    query : str
        The search query string.

    Returns:
    -------
    str
        The search results as a string returned by the Tavily Search API.

    Raises:
    ------
    Exception
        Errors during the search are caught and returned as error strings.
    """
    if not TAVILY_AVAILABLE:
        return "Web search is not available. Tavily API key is not configured."

    query_params = {"query": query, "auto_parameters": True}

    try:
        result = tavily_search.invoke(query_params)
        return result
    except Exception as e:
        return f"Error during Tavily search: {str(e)}"


# State definition
class State(TypedDict):
    # add_messages is a reducer: it does not replace the list but appends messages to it.
    # Annotated[list[BaseMessage], add_messages] behaves the same; BaseMessage is not required.
    messages: Annotated[list, add_messages]


async def create_graph(kb_tool: bool, mcp_config: dict):
    if mcp_config and MCP_AVAILABLE:
        server_config = {
            "url": mcp_config["url"],
            "transport": "streamable_http",
        }

        # Add headers if a bearer token exists
        if mcp_config.get("bearerToken"):
            server_config["headers"] = {
                "Authorization": f"Bearer {mcp_config['bearerToken']}"
            }

        client = MultiServerMCPClient({mcp_config["name"]: server_config})
        mcp_tools = await client.get_tools()
    else:
        mcp_tools = []
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        google_api_key=GOOGLE_API_KEY,
        temperature=GEMINI_TEMPERATURE,
    )
    if kb_tool:
        tools = [search_docs, search_government_docs, analyze_scenario, search_tool]
    else:
        tools = [search_tool, analyze_scenario]
    tools = tools + mcp_tools
    llm_with_tools = llm.bind_tools(tools)

    async def llm_node(state: State):
        messages = state["messages"]
        response = await llm_with_tools.ainvoke(messages)
        return {"messages": [response]}

    builder = StateGraph(State)
    builder.add_node("llm_with_tools", llm_node)
    tool_node = ToolNode(tools=tools, handle_tool_errors=True)
    builder.add_node("tools", tool_node)
    # tools_condition routes to "tools" when the LLM requests a tool call and to END otherwise,
    # so no separate unconditional edge from "llm_with_tools" to END is needed.
    builder.add_conditional_edges("llm_with_tools", tools_condition)
    builder.add_edge("tools", "llm_with_tools")
    builder.add_edge(START, "llm_with_tools")
    return builder.compile()


# Build a basic graph (no tools, no memory)
def create_basic_graph():
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        google_api_key=GOOGLE_API_KEY,
        temperature=GEMINI_TEMPERATURE,
    )

    async def llm_basic_node(state: State):
        messages = state["messages"]
        system_prompt = SystemMessage(
            content="""You are a helpful and friendly voice AI assistant. Your responses should be:

- Conversational and natural, as if speaking to a friend
- Concise but informative - aim for 1-3 sentences unless more detail is specifically requested
- Clear and easy to understand when spoken aloud
- Engaging and personable while remaining professional
- Free of overly complex language and long lists that are hard to follow in audio format

When responding:
- Use a warm, approachable tone
- Speak in a natural rhythm suitable for text-to-speech
- If you need to provide multiple items or steps, break them into digestible chunks
- Ask clarifying questions when needed to better assist the user
- Acknowledge when you don't know something rather than guessing

Remember that users are interacting with you through voice, so structure your responses to be easily understood when heard rather than read.
Don't use abbreviations or numerical content in your responses."""
        )
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages.insert(0, system_prompt)
        return {"messages": [await llm.ainvoke(messages)]}

    builder = StateGraph(State)
    builder.add_node("llm_basic", llm_basic_node)
    builder.set_entry_point("llm_basic")
    builder.add_edge("llm_basic", END)
    return builder.compile()  # No checkpointing
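A minimal sketch of compiling and running both graphs, assuming GOOGLE_API_KEY and GEMINI_MODEL are configured in config.py:

import asyncio
from langchain_core.messages import HumanMessage
from llm_service import create_basic_graph, create_graph

async def demo():
    # Basic graph: a single LLM node, no tools.
    basic = create_basic_graph()
    out = await basic.ainvoke({"messages": [HumanMessage(content="Hello!")]})
    print(out["messages"][-1].content)

    # Tool-enabled graph: knowledge-base tools on, no MCP server attached.
    graph = await create_graph(kb_tool=True, mcp_config=None)
    out = await graph.ainvoke({"messages": [HumanMessage(content="Search the docs for pension rules")]})
    print(out["messages"][-1].content)

asyncio.run(demo())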
main.py
ADDED
@@ -0,0 +1,21 @@
"""
Voice Bot Application - Entry Point

This file has been refactored; the main application logic now lives in app.py.
Run: python app.py, or: uvicorn app:app

The modular structure:
- config.py: Configuration and constants
- audio_services.py: ASR and TTS functionality
- rag_service.py: Vector store and document search
- llm_service.py: LangGraph and LLM handling
- document_service.py: PDF processing and document upload
- websocket_handler.py: WebSocket connection handling
- app.py: FastAPI application and routes
"""

from app import app

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
rag_service.py
ADDED
@@ -0,0 +1,322 @@
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.tools import tool
from config import EMBEDDING_MODEL_NAME
from langchain_core.runnables import RunnableConfig
from typing import List, Dict, Any
from lancedb_service import lancedb_service
from scenario_analysis_service import scenario_service
import logging
import json
import asyncio

logger = logging.getLogger("voicebot")

# Set up the embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    model_kwargs={
        "device": "cpu",
        "trust_remote_code": True
    },
    encode_kwargs={
        "normalize_embeddings": True
    }
)

async def get_user_knowledge_bases(userid: str) -> List[str]:
    """Get all knowledge bases for a user"""
    try:
        return await lancedb_service.get_user_knowledge_bases(userid)
    except Exception as e:
        logger.error(f"❌ Error fetching knowledge bases: {e}")
        return []

async def get_kb_documents(user_id: str, kb_name: str):
    """Get all documents in a knowledge base"""
    try:
        return await lancedb_service.get_kb_documents(user_id, kb_name)
    except Exception as e:
        logger.error(f"❌ Error fetching documents: {e}")
        return []

async def delete_document_from_kb(user_id: str, kb_name: str, filename: str):
    """Delete a document from a knowledge base"""
    try:
        return await lancedb_service.delete_document_from_kb(user_id, kb_name, filename)
    except Exception as e:
        logger.error(f"❌ Error deleting document: {e}")
        return False

def search_documents(query: str, limit: int = 5) -> List[Dict[str, Any]]:
    """
    Synchronous wrapper for searching documents in the government knowledge base.
    Returns a list of documents with content for compatibility with existing code.
    """
    try:
        # Run the async search function synchronously
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Determine which knowledge bases to search based on the query content
            knowledge_bases = ["government_docs"]  # Default

            # Route to specific knowledge bases based on query keywords
            query_lower = query.lower()
            if any(keyword in query_lower for keyword in ["rajasthan", "pension", "circular", "pay", "rules"]):
                # Rajasthan documents live in a separate table
                return search_rajasthan_documents(query, limit)

            all_docs = []

            # Search across all relevant knowledge bases
            for kb in knowledge_bases:
                try:
                    docs = loop.run_until_complete(
                        lancedb_service.similarity_search(query, "system", kb, k=limit)
                    )
                    all_docs.extend(docs)
                except Exception as e:
                    logger.warning(f"Search failed for knowledge base {kb}: {e}")
                    continue

            if not all_docs:
                return []

            # Sort by relevance score if available and limit the results
            all_docs = sorted(all_docs, key=lambda x: getattr(x, 'score', 1.0), reverse=True)[:limit]

            # Convert to the expected format
            results = []
            for doc in all_docs:
                results.append({
                    "content": doc.page_content,
                    "source": doc.metadata.get('source', 'Unknown'),
                    "score": getattr(doc, 'score', 1.0)
                })

            logger.info(f"🔍 Found {len(results)} documents for query: {query}")
            return results

        finally:
            loop.close()

    except Exception as e:
        logger.error(f"❌ Error in search_documents: {e}")
        return []

def search_rajasthan_documents(query: str, limit: int = 5) -> List[Dict[str, Any]]:
    """
    Search specifically in the Rajasthan documents table using a direct LanceDB query.
    """
    try:
        import lancedb

        # Connect to LanceDB
        db = lancedb.connect('./lancedb_data')

        # Check whether the rajasthan_documents table exists
        if 'rajasthan_documents' not in db.table_names():
            logger.warning("⚠️ Rajasthan documents table not found")
            return []

        # Open the table
        tbl = db.open_table('rajasthan_documents')

        # Create an embedding for the query
        query_embedding = embedding_model.embed_query(query)

        # Search using vector similarity
        search_results = tbl.search(query_embedding).limit(limit).to_pandas()

        if search_results.empty:
            logger.info(f"🔍 No results found in Rajasthan documents for: {query}")
            return []

        # Convert to the expected format
        results = []
        for _, row in search_results.iterrows():
            results.append({
                "content": row['content'],
                "source": row['filename'],
                "score": float(row.get('_distance', 1.0))  # LanceDB returns _distance
            })

        logger.info(f"🔍 Found {len(results)} Rajasthan documents for query: {query}")
        return results

    except Exception as e:
        logger.error(f"❌ Error searching Rajasthan documents: {e}")
        return []

@tool
async def search_docs(query: str, config: RunnableConfig) -> str:
    """Search the knowledge base for relevant context within a specific knowledge base."""
    userid = config["configurable"].get("thread_id")
    knowledge_base = config["configurable"].get("knowledge_base", "government_docs")

    try:
        # Search in the specified knowledge base
        docs = await lancedb_service.similarity_search(query, userid, knowledge_base)

        if not docs:
            return "No relevant documents found in the knowledge base."

        context = "\n\n".join([doc.page_content for doc in docs])
        return f"📚 Found {len(docs)} relevant documents:\n\n{context}"

    except Exception as e:
        logger.error(f"❌ Error searching documents: {e}")
        return "Error occurred while searching documents."

@tool
async def search_government_docs(query: str, config: RunnableConfig) -> str:
    """Search government documents for relevant information and policies."""
    try:
        # Search specifically in the government_docs knowledge base
        docs = await lancedb_service.similarity_search(query, "system", "government_docs")

        if not docs:
            return "No relevant government documents found for your query."

        context = "\n\n".join([doc.page_content for doc in docs])
        sources = list(set([doc.metadata.get('source', 'Unknown') for doc in docs]))

        result = f"📚 Found {len(docs)} relevant government documents:\n\n{context}"
        if sources:
            result += f"\n\n📄 Sources: {', '.join(sources)}"

        return result

    except Exception as e:
        logger.error(f"❌ Error searching government documents: {e}")
        return "Error occurred while searching government documents."

@tool
async def analyze_scenario(scenario_query: str, config: RunnableConfig) -> str:
    """
    Analyze government scenarios and create visualizations including charts, graphs, and diagrams.
    Use this tool when users ask for scenario analysis, data visualization, charts, graphs, or diagrams
    related to government processes, budgets, policies, organizational structures, or performance metrics.

    Args:
        scenario_query: Description of the scenario to analyze (e.g., "budget analysis for health department",
                        "policy implementation timeline", "organizational structure", "performance metrics")
    """
    try:
        logger.info(f"📊 Analyzing scenario: {scenario_query}")

        # Parse the scenario query to determine its type and extract data
        scenario_data = await _parse_scenario_query(scenario_query)

        # Perform the scenario analysis
        result = await scenario_service.analyze_government_scenario(scenario_data)

        if result.get("success", False):
            # Format the response with images
            response = "📊 **Scenario Analysis Complete!**\n\n"
            response += result.get("analysis", "")
            response += f"\n\n🖼️ **Generated {len(result.get('images', []))} visualization(s)**"

            # Append image information for frontend rendering
            if result.get("images"):
                response += "\n\n**SCENARIO_IMAGES_START**\n"
                response += json.dumps(result["images"])
                response += "\n**SCENARIO_IMAGES_END**"

            return response
        else:
            return f"❌ Error in scenario analysis: {result.get('error', 'Unknown error')}"

    except Exception as e:
        logger.error(f"❌ Error in scenario analysis tool: {e}")
        return f"Error occurred while analyzing scenario: {str(e)}"

async def _parse_scenario_query(query: str) -> Dict[str, Any]:
    """Parse a scenario query to determine its type and extract relevant data"""
    query_lower = query.lower()

    # Determine the scenario type based on keywords
    if any(word in query_lower for word in ["budget", "financial", "expenditure", "allocation", "funding"]):
        scenario_type = "budget"
        # Extract budget data if mentioned in the query
        data = _extract_budget_data(query)
    elif any(word in query_lower for word in ["policy", "implementation", "timeline", "plan", "strategy"]):
        scenario_type = "policy"
        data = _extract_policy_data(query)
    elif any(word in query_lower for word in ["organization", "hierarchy", "structure", "reporting", "org"]):
        scenario_type = "organization"
        data = _extract_org_data(query)
    elif any(word in query_lower for word in ["performance", "metrics", "kpi", "efficiency", "evaluation"]):
        scenario_type = "performance"
        data = _extract_performance_data(query)
    elif any(word in query_lower for word in ["workflow", "process", "flow", "procedure", "steps"]):
        scenario_type = "workflow"
        data = _extract_workflow_data(query)
    else:
        scenario_type = "general"
        data = {}

    return {
        "type": scenario_type,
        "title": f"Government {scenario_type.title()} Analysis",
        "data": data
    }

def _extract_budget_data(query: str) -> Dict[str, Any]:
    """Extract budget-related data from the query"""
    # This could be enhanced with NLP to extract actual numbers and departments;
    # for now, return an empty data structure.
    return {}

def _extract_policy_data(query: str) -> Dict[str, Any]:
    """Extract policy-related data from the query"""
    return {}

def _extract_org_data(query: str) -> Dict[str, Any]:
    """Extract organizational data from the query"""
    return {}

def _extract_performance_data(query: str) -> Dict[str, Any]:
    """Extract performance data from the query"""
    return {}

def _extract_workflow_data(query: str) -> Dict[str, Any]:
    """Extract workflow data from the query"""
    return {}

if __name__ == "__main__":
    import asyncio

    async def test_search():
        print("🔍 Testing the search_docs RAG tool with the LanceDB vector store...\n")

        test_user_id = "test_user_123"
        test_knowledge_base = "test_kb"

        while True:
            user_input = input("Enter a query (or 'exit'): ").strip()
            if user_input.lower() == "exit":
                break

            kb_input = input(f"Knowledge base (current: {test_knowledge_base}, press Enter to keep): ").strip()
            if kb_input:
                test_knowledge_base = kb_input

            try:
                result = await search_docs.ainvoke(
                    {"query": user_input},
                    config=RunnableConfig(
                        configurable={
                            "thread_id": test_user_id,
                            "knowledge_base": test_knowledge_base
                        }
                    )
                )
                print(f"\n📄 Results from '{test_knowledge_base}' knowledge base:\n")
                print(result)
                print("\n" + "="*50 + "\n")
            except Exception as e:
                print(f"❌ Error: {e}")

    asyncio.run(test_search())
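A minimal sketch of invoking the government-docs tool directly, mirroring the test harness above and assuming documents have already been ingested into the "government_docs" knowledge base:

import asyncio
from langchain_core.runnables import RunnableConfig
from rag_service import search_government_docs

async def demo():
    result = await search_government_docs.ainvoke(
        {"query": "leave encashment rules"},
        config=RunnableConfig(configurable={"thread_id": "demo_user"})
    )
    print(result)

asyncio.run(demo())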
requirements.txt
ADDED
@@ -0,0 +1,32 @@
dotenv>=0.9.9
fastapi>=0.115.14
gradio>=4.44.0
requests>=2.31.0
langchain>=0.3.26
langchain-community>=0.3.27
langchain-huggingface>=0.3.0
langchain-google-genai>=2.0.1
langchain-groq>=0.3.0
langchain-tavily>=0.2.7
langgraph>=0.5.1
langsmith>=0.4.4
lancedb>=0.13.0
google-generativeai>=0.8.1
pdfplumber>=0.11.7
pip>=25.1.1
pyjwt>=2.10.1
python-multipart>=0.0.20
sentence-transformers>=5.0.0
uvicorn[standard]>=0.35.0
pandas>=2.0.0
pyarrow>=14.0.0
einops>=0.8.0
matplotlib>=3.7.0
seaborn>=0.12.0
plotly>=5.15.0
networkx>=3.1
pillow>=10.0.0
edge-tts>=6.1.0
openai-whisper>=20231117  # "whisper" on PyPI is the Graphite database, not the ASR model that voice_service.py imports
pydub>=0.25.1
websockets>=11.0.0
voice_service.py
ADDED
@@ -0,0 +1,324 @@
"""
Voice Service for optional Text-to-Speech (TTS) and Automatic Speech Recognition (ASR)
Provides voice interaction capabilities when enabled by user.
"""

import asyncio
import logging
import tempfile
import os
from typing import Optional, Dict, Any
from pathlib import Path

from config import (
    ENABLE_VOICE_FEATURES, TTS_PROVIDER, ASR_PROVIDER,
    VOICE_LANGUAGE, DEFAULT_VOICE_SPEED
)

logger = logging.getLogger("voicebot")

class VoiceService:
    def __init__(self):
        self.voice_enabled = ENABLE_VOICE_FEATURES
        self.tts_provider = TTS_PROVIDER
        self.asr_provider = ASR_PROVIDER
        self.language = VOICE_LANGUAGE
        self.voice_speed = DEFAULT_VOICE_SPEED

        # Initialize services if voice is enabled
        if self.voice_enabled:
            self._init_tts_service()
            self._init_asr_service()
            logger.info(f"🎤 Voice Service initialized - TTS: {self.tts_provider}, ASR: {self.asr_provider}")
        else:
            logger.info("🔇 Voice features disabled")

    def _init_tts_service(self):
        """Initialize Text-to-Speech service"""
        try:
            if self.tts_provider == "edge-tts":
                import edge_tts
                self.tts_available = True
                logger.info("✅ Edge TTS initialized")
            elif self.tts_provider == "openai-tts":
                # OpenAI TTS would require an OpenAI API key
                self.tts_available = False
                logger.info("⚠️ OpenAI TTS not configured")
            else:
                self.tts_available = False
                logger.warning(f"⚠️ Unknown TTS provider: {self.tts_provider}")
        except ImportError as e:
            self.tts_available = False
            logger.warning(f"⚠️ TTS dependencies not available: {e}")

    def _init_asr_service(self):
        """Initialize Automatic Speech Recognition service"""
        try:
            if self.asr_provider == "whisper":
                import whisper
                # Use the base model for a balance between speed and accuracy
                self.whisper_model = whisper.load_model("base")
                self.asr_available = True
                logger.info("✅ Whisper ASR initialized (base model for accuracy)")
            elif self.asr_provider == "browser-native":
                # Browser-based ASR doesn't require server-side setup
                self.asr_available = True
                logger.info("✅ Browser ASR configured")
            else:
                self.asr_available = False
                logger.warning(f"⚠️ Unknown ASR provider: {self.asr_provider}")
        except ImportError as e:
            self.asr_available = False
            logger.warning(f"⚠️ ASR dependencies not available: {e}")

    def _get_language_code(self, user_language: str = None) -> str:
        """
        Convert user language preference to Whisper language code

        Args:
            user_language: User's language preference ('english', 'hindi', 'hi-IN', etc.)

        Returns:
            Two-letter language code for Whisper (e.g., 'en', 'hi')
        """
        if not user_language:
            # Fall back to the default config language
            return self.language.split('-')[0] if self.language else 'en'

        # Handle different language format inputs
        user_lang_lower = user_language.lower()

        # Map common language names to codes
        language_mapping = {
            'english': 'en',
            'hindi': 'hi',
            'hinglish': 'hi',  # Treat Hinglish as Hindi for better results
            'en': 'en',
            'hi': 'hi',
            'en-in': 'en',
            'hi-in': 'hi',
            'en-us': 'en'
        }

        # Extract base language if it's a locale code (e.g., 'hi-IN' -> 'hi')
        if '-' in user_lang_lower:
            base_lang = user_lang_lower.split('-')[0]
            return language_mapping.get(base_lang, 'en')

        return language_mapping.get(user_lang_lower, 'en')

    def _get_default_voice(self) -> str:
        """Get default voice based on language setting"""
        language_voices = {
            'hi-IN': 'hi-IN-SwaraNeural',    # Hindi (India) female voice
            'en-IN': 'en-IN-NeerjaNeural',   # English (India) female voice
            'en-US': 'en-US-AriaNeural',     # English (US) female voice
            'es-ES': 'es-ES-ElviraNeural',   # Spanish (Spain) female voice
            'fr-FR': 'fr-FR-DeniseNeural',   # French (France) female voice
            'de-DE': 'de-DE-KatjaNeural',    # German (Germany) female voice
            'ja-JP': 'ja-JP-NanamiNeural',   # Japanese female voice
            'ko-KR': 'ko-KR-SunHiNeural',    # Korean female voice
            'zh-CN': 'zh-CN-XiaoxiaoNeural'  # Chinese (Simplified) female voice
        }
        return language_voices.get(self.language, 'en-US-AriaNeural')

    async def text_to_speech(self, text: str, voice: str = None) -> Optional[bytes]:
        """
        Convert text to speech audio
        Returns audio bytes or None if TTS not available
        """
        if not self.voice_enabled or not self.tts_available:
            return None

        # Use the default voice for the configured language if no voice specified
        if voice is None:
            voice = self._get_default_voice()

        try:
            if self.tts_provider == "edge-tts":
                import edge_tts

                # Create TTS communication (voice_speed 1.0 maps to a "+0%" rate)
                communicate = edge_tts.Communicate(text, voice, rate=f"{int((self.voice_speed - 1) * 100):+d}%")

                # Generate audio
                audio_data = b""
                async for chunk in communicate.stream():
                    if chunk["type"] == "audio":
                        audio_data += chunk["data"]

                return audio_data

        except Exception as e:
            logger.error(f"❌ TTS Error: {e}")
            return None

    async def speech_to_text(self, audio_file_path: str, user_language: str = None) -> Optional[str]:
        """
        Convert speech audio to text
        Returns transcribed text or None if ASR not available

        Args:
            audio_file_path: Path to the audio file
            user_language: User's preferred language (e.g., 'english', 'hindi', 'hi-IN')
        """
        if not self.voice_enabled or not self.asr_available:
            return None

        try:
            if self.asr_provider == "whisper":
                # Determine language code based on user preference or default
                language_code = self._get_language_code(user_language)

                logger.info(f"🎤 Using Whisper with language: {language_code} (user_pref: {user_language})")

                # Use enhanced transcription options for better accuracy
                transcribe_options = {
                    "fp16": False,       # Use FP32 for better accuracy on CPU
                    "temperature": 0.0,  # Deterministic output
                    "best_of": 1,        # Use best transcription
                    "beam_size": 5,      # Better beam search
                    "patience": 1.0,     # Wait for better results
                }

                if language_code and language_code != 'en':
                    transcribe_options["language"] = language_code
                    result = self.whisper_model.transcribe(audio_file_path, **transcribe_options)
                    logger.info(f"🎤 {language_code.upper()} transcription result: {result.get('text', '')}")
                else:
                    result = self.whisper_model.transcribe(audio_file_path, **transcribe_options)
                    logger.info(f"🎤 English transcription result: {result.get('text', '')}")

                transcribed_text = result["text"].strip()

                # Log confidence/quality metrics if available
                if "segments" in result and result["segments"]:
                    avg_confidence = sum(seg.get("no_speech_prob", 0) for seg in result["segments"]) / len(result["segments"])
                    logger.info(f"🎤 Average confidence: {1-avg_confidence:.2f}")

                return transcribed_text

        except Exception as e:
            logger.error(f"❌ ASR Error: {e}")
            return None

    def get_available_voices(self) -> Dict[str, Any]:
        """Get list of available TTS voices"""
        if not self.voice_enabled or self.tts_provider != "edge-tts":
            return {}

        # Common Edge TTS voices
        voices = {
            "english": {
                "female": ["en-US-AriaNeural", "en-US-JennyNeural", "en-GB-SoniaNeural"],
                "male": ["en-US-GuyNeural", "en-US-DavisNeural", "en-GB-RyanNeural"]
            },
            "multilingual": {
                "spanish": ["es-ES-ElviraNeural", "es-MX-DaliaNeural"],
                "french": ["fr-FR-DeniseNeural", "fr-CA-SylvieNeural"],
                "german": ["de-DE-KatjaNeural", "de-AT-IngridNeural"],
                "italian": ["it-IT-ElsaNeural", "it-IT-IsabellaNeural"],
                "hindi": ["hi-IN-SwaraNeural", "hi-IN-MadhurNeural"]
            }
        }
        return voices

    def create_voice_response_with_guidance(self,
                                            answer: str,
                                            suggested_resources: list = None,
                                            redirect_info: str = None) -> str:
        """
        Create a comprehensive voice response with guidance and redirection
        """
        response_parts = []

        # Main answer
        response_parts.append(answer)

        # Add guidance for further information
        if suggested_resources:
            response_parts.append("\nFor more detailed information, I recommend checking:")
            for resource in suggested_resources:
                response_parts.append(f"• {resource}")

        # Add redirection information
        if redirect_info:
            response_parts.append(f"\nYou can also {redirect_info}")

        # Add helpful voice interaction tips
        response_parts.append("\nIs there anything specific you'd like me to explain further? Just ask!")

        return " ".join(response_parts)

    def generate_redirect_suggestions(self, topic: str, query_type: str) -> Dict[str, Any]:
        """
        Generate contextual redirect suggestions based on the topic and query type
        """
        suggestions = {
            "documents": [],
            "websites": [],
            "departments": [],
            "redirect_text": ""
        }

        # Government policy topics
        if "digital india" in topic.lower():
            suggestions["documents"] = [
                "Digital India Policy Framework 2023",
                "E-Governance Implementation Guidelines"
            ]
            suggestions["websites"] = ["digitalindia.gov.in", "meity.gov.in"]
            suggestions["departments"] = ["Ministry of Electronics & IT"]
            suggestions["redirect_text"] = "visit the official Digital India portal or contact your local e-governance center"

        elif "education" in topic.lower():
            suggestions["documents"] = [
                "National Education Policy 2020",
                "Sarva Shiksha Abhiyan Guidelines"
            ]
            suggestions["websites"] = ["education.gov.in", "mhrd.gov.in"]
            suggestions["departments"] = ["Ministry of Education"]
            suggestions["redirect_text"] = "contact your District Education Officer or visit the nearest education department office"

        elif "health" in topic.lower():
            suggestions["documents"] = [
                "National Health Policy 2017",
                "Ayushman Bharat Implementation Guide"
            ]
            suggestions["websites"] = ["mohfw.gov.in", "pmjay.gov.in"]
            suggestions["departments"] = ["Ministry of Health & Family Welfare"]
            suggestions["redirect_text"] = "visit your nearest Primary Health Center or call the health helpline"

        elif "employment" in topic.lower() or "job" in topic.lower():
            suggestions["documents"] = [
                "Employment Generation Schemes",
                "Skill Development Programs Guide"
            ]
            suggestions["websites"] = ["nrega.nic.in", "msde.gov.in"]
            suggestions["departments"] = ["Ministry of Rural Development", "Ministry of Skill Development"]
            suggestions["redirect_text"] = "visit your local employment exchange or skill development center"

        # Default for other topics
        if not suggestions["redirect_text"]:
            suggestions["redirect_text"] = "contact the relevant government department or visit your local district collector's office"

        return suggestions

    def is_voice_enabled(self) -> bool:
        """Check if voice features are enabled"""
        return self.voice_enabled

    def get_voice_status(self) -> Dict[str, Any]:
        """Get current voice service status"""
        return {
            "voice_enabled": self.voice_enabled,
            "tts_available": getattr(self, 'tts_available', False),
            "asr_available": getattr(self, 'asr_available', False),
            "tts_provider": self.tts_provider,
            "asr_provider": self.asr_provider,
            "language": self.language,
            "voice_speed": self.voice_speed
        }

# Global instance
voice_service = VoiceService()
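A minimal usage sketch for the service above, assuming voice features are enabled in config.py and edge-tts is installed; the reply text and output filename are illustrative. Edge TTS streams MP3 audio, so the returned bytes can be written straight to an .mp3 file:

import asyncio
from voice_service import voice_service

async def demo():
    audio = await voice_service.text_to_speech("Your pension application has been received.")
    if audio:  # None when voice is disabled or TTS is unavailable
        with open("reply.mp3", "wb") as f:
            f.write(audio)

asyncio.run(demo())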
voice_websocket_server.py
ADDED
@@ -0,0 +1,492 @@
#!/usr/bin/env python3
"""
Voice-enabled WebSocket server that combines the full voice backend with our document search
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json
import logging
import re
import tempfile
import lancedb
import pandas as pd
import asyncio
import os
from dotenv import load_dotenv
from dataclasses import asdict, is_dataclass

# Try to import voice services; fall back to text-only mode if unavailable
try:
    from hybrid_llm_service import HybridLLMService
    from voice_service import VoiceService
    from settings_api import router as settings_router
    from policy_simulator_api import router as policy_simulator_router
    VOICE_AVAILABLE = True
except ImportError:
    VOICE_AVAILABLE = False
    logging.warning("Voice services not available, text-only mode")

load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simple response cache for common queries
response_cache = {}
MAX_CACHE_SIZE = 100

app = FastAPI()

# Include API routers
if VOICE_AVAILABLE:
    app.include_router(settings_router)
    app.include_router(policy_simulator_router)

# Enable CORS - include both local development and production origins
allowed_origins = [
    "http://localhost:5176", "http://localhost:5177",
    "http://127.0.0.1:5176", "http://127.0.0.1:5177",
    "http://localhost:3000", "http://localhost:5173",
    "https://*.vercel.app", "https://*.hf.space"
]

# Add any custom origins from the environment; expects a JSON list,
# e.g. ALLOWED_ORIGINS='["https://example.com"]'
if os.getenv("ALLOWED_ORIGINS"):
    try:
        custom_origins = json.loads(os.getenv("ALLOWED_ORIGINS"))
        if isinstance(custom_origins, list):
            allowed_origins.extend(custom_origins)
    except json.JSONDecodeError:
        pass

app.add_middleware(
    CORSMiddleware,
    # NOTE: CORSMiddleware matches allow_origins exactly; wildcard subdomain
    # entries like "https://*.vercel.app" would need allow_origin_regex instead.
    allow_origins=["*"] if "*" in allowed_origins else allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize services if available
if VOICE_AVAILABLE:
    try:
        hybrid_llm_service = HybridLLMService()
        voice_service = VoiceService()
        logger.info("✅ Voice services initialized")
    except Exception as e:
        logger.warning(f"⚠️ Voice services failed to initialize: {e}")
        VOICE_AVAILABLE = False

def serialize_for_json(obj):
    """Custom JSON serializer for policy simulation objects"""
    if is_dataclass(obj):
        return asdict(obj)
    elif hasattr(obj, '__dict__'):
        return obj.__dict__
    elif isinstance(obj, (list, tuple)):
        return [serialize_for_json(item) for item in obj]
    elif isinstance(obj, dict):
        return {key: serialize_for_json(value) for key, value in obj.items()}
    else:
        return obj

def search_documents_simple(query: str):
    """Simple keyword-based document search without embeddings"""
    try:
        db = lancedb.connect('./lancedb_data')

        # Check for Rajasthan documents first
        if 'rajasthan_documents' in db.table_names():
            tbl = db.open_table('rajasthan_documents')
            df = tbl.to_pandas()

            # Enhanced search for Rajasthan/pension queries
            query_lower = query.lower()
            is_pension_query = any(keyword in query_lower for keyword in [
                'pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance',
                'old age', 'social security', 'retirement', 'सेवानिवृत्ति'
            ])

            if is_pension_query or 'rajasthan' in query_lower:
                # Enhanced pension search with more keywords
                pension_filter = df['content'].str.contains(
                    'pension|Pension|पेंशन|वृद्धावस्था|सामाजिक|भत्ता|allowance|old.age|social.security|retirement|सेवानिवृत्ति|scheme|योजना',
                    case=False, na=False, regex=True
                )

                relevant_docs = df[pension_filter]

                if not relevant_docs.empty:
                    # Sort by relevance (count of matched keywords)
                    def score_relevance(content):
                        keywords = ['pension', 'पेंशन', 'वृद्धावस्था', 'सामाजिक', 'भत्ता', 'allowance', 'old age']
                        return sum(1 for keyword in keywords if keyword in content.lower())

                    relevant_docs = relevant_docs.copy()
                    relevant_docs['relevance_score'] = relevant_docs['content'].apply(score_relevance)
                    relevant_docs = relevant_docs.sort_values('relevance_score', ascending=False)

                    results = []
                    for _, row in relevant_docs.head(5).iterrows():
                        results.append({
                            "content": row['content'][:800],
                            "filename": row['filename']
                        })
                    return results, "rajasthan_pension_documents"

        return [], "none"

    except Exception as e:
        logger.error(f"Search error: {e}")
        return [], "error"

# Policy simulation detection patterns (robust regexes, checked per message)
POLICY_PATTERNS = [
    r"policy.*simulation|simulation.*policy",
    r"policy.*scenario|scenario.*policy",
    r"policy.*analysis|analysis.*policy",
    r"pension.*simulation|simulation.*pension",
    r"pension.*analysis|analysis.*pension",
    r"pension.*scenario|scenario.*pension",
    r"dearness.*relief|dr.*increase|dr.*adjustment",
    r"dearness.*allowance|da.*increase|da.*adjustment",
    r"minimum.*pension.*increase|increase.*minimum.*pension",
    r"calculate.*pension|pension.*calculation",
    r"impact.*dr|dr.*impact|impact.*da|da.*impact",
    r"show.*impact.*da|show.*impact.*dr",
    r"impact.*\d+.*da|impact.*\d+.*dr",
    r"\d+.*da.*increase|da.*\d+.*increase",
    r"\d+.*dr.*increase|dr.*\d+.*increase",
    r"inflation.*adjustment|adjustment.*inflation",
    r"scenario.*analysis|analysis.*scenario",
    r"what.*if.*dr|what.*if.*pension|what.*if.*da",
    r"compare.*scenario|scenario.*comparison",
    r"show.*chart|chart.*show",
    r"explain.*chart|chart.*explain",
    r"using.*chart|chart.*using",
    r"dr.*\d+.*increase|increase.*dr.*\d+",
    r"da.*\d+.*increase|increase.*da.*\d+",
    r"analyze.*minimum.*pension",
    r"pension.*change",
    r"make.*chart|chart.*make",
    r"pension.*value|value.*pension",
    r"basic.*pension.*\d+|pension.*\d+",
    r"simulate.*dr|simulate.*pension|simulate.*da"
]

def is_policy_simulation_query(message: str) -> bool:
    """Check if the message is a policy simulation query"""
    message_lower = message.lower()
    logger.info(f"🔍 Checking policy patterns for: '{message_lower}'")

    for i, pattern in enumerate(POLICY_PATTERNS):
        if re.search(pattern, message_lower, re.IGNORECASE):
            logger.info(f"✅ Pattern {i+1} matched: {pattern}")
            return True

    logger.info("❌ No policy patterns matched")
    return False

async def get_llm_response(query: str, search_results: list):
    """Get response using available LLM service with caching"""
    # Create cache key based on query and search results
    cache_key = f"{query}_{len(search_results) if search_results else 0}"

    # Check cache first
    if cache_key in response_cache:
        logger.info(f"📦 Cache hit for query: {query[:50]}...")
        return response_cache[cache_key]

    try:
        if VOICE_AVAILABLE and hybrid_llm_service:
            # Use the hybrid LLM service
            if search_results:
                context = "\n\n".join([f"Document: {doc['filename']}\nContent: {doc['content']}" for doc in search_results])
                enhanced_query = f"Based on these Rajasthan government documents, please answer: {query}\n\nDocuments:\n{context}"
            else:
                enhanced_query = query

            response = await hybrid_llm_service.get_response(enhanced_query)

            # Cache the response
            if len(response_cache) >= MAX_CACHE_SIZE:
                # Remove the oldest entry (dicts preserve insertion order)
                response_cache.pop(next(iter(response_cache)))
            response_cache[cache_key] = response

            return response
        else:
            # Fallback to simple response
            if search_results:
                response = f"Based on the Rajasthan government documents, I found information about {query}. However, voice processing is currently limited. Please use text chat for detailed responses."
            else:
                response = f"I received your query about '{query}' but couldn't find specific documents. Please try using text chat for better results."

            # Cache fallback response too
            response_cache[cache_key] = response
            return response

    except Exception as e:
        logger.error(f"LLM error: {e}")
        return "I'm having trouble processing your request. Please try using the text chat."

@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    logger.info("🔌 WebSocket client connected")

    # Store user session info
    user_language = "english"  # default language

    try:
        # Send initial greeting
        await websocket.send_json({
            "type": "connection_successful",
            "message": "Hello! I'm your Rajasthan government document assistant. I can help with text and voice queries about pension schemes and government policies."
        })

        while True:
            try:
                # Receive message with better error handling
                message = await websocket.receive()

                # Handle different message types
                if message["type"] == "websocket.receive":
                    if "text" in message:
                        # Parse JSON text message
                        try:
                            data = json.loads(message["text"])
                        except json.JSONDecodeError:
                            logger.warning(f"⚠️ Invalid JSON received: {message['text']}")
                            continue

                        # Process text message
                        if isinstance(data, dict) and data.get("type") == "text_message":
                            user_message = data.get("message", "")
                            if not user_message.strip():
                                continue

                            logger.info(f"💬 Text received: {user_message}")

                            # Check for interactive scenario form triggers
                            form_triggers = [
                                "start scenario analysis", "scenario form", "interactive analysis",
                                "step by step analysis", "guided analysis", "form analysis",
                                "scenario chat form", "interactive scenario"
                            ]
                            is_form_request = any(trigger in user_message.lower() for trigger in form_triggers)

                            # Check if this is a policy simulation query
                            is_policy_query = is_policy_simulation_query(user_message)

                            # Handle interactive scenario form request
                            if is_form_request:
                                logger.info("📋 Interactive scenario form requested")

                                try:
                                    from scenario_chat_form import start_scenario_analysis_form
                                    form_response = start_scenario_analysis_form(data.get("user_id", "default"))

                                    # Format form response for chat
                                    form_message = f"""🎯 **{form_response.get('title', 'Interactive Scenario Analysis')}**

{form_response.get('message', '')}

**{form_response.get('step_title', 'Step 1')}** ({form_response.get('current_step', 1)}/{form_response.get('total_steps', 4)})

{form_response['form_data']['question']}

**Available Options:**"""

                                    # Add form options
                                    if form_response['form_data']['input_type'] == 'select':
                                        for i, option in enumerate(form_response['form_data']['options'], 1):
                                            form_message += f"\n{i}. {option['label']}"

                                    form_message += "\n\n**Quick Actions:**"
                                    for action in form_response.get('quick_actions', []):
                                        form_message += f"\n• {action['text']}"

                                    form_message += "\n\n💡 **Next:** Choose an option above or type your selection!"

                                    await websocket.send_json({
                                        "type": "interactive_form",
                                        "message": form_message,
                                        "form_data": form_response
                                    })
                                    continue

                                except Exception as e:
                                    logger.error(f"Form initialization failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": f"Sorry, I couldn't start the interactive scenario analysis. Error: {str(e)}"
                                    })
                                    continue

                            # Handle policy queries
                            elif is_policy_query:
                                logger.info("🎯 Detected policy simulation query")

                                try:
                                    # Import policy chat interface
                                    from policy_chat_interface import PolicySimulatorChatInterface

                                    # Send acknowledgment for policy simulation
                                    await websocket.send_json({
                                        "type": "message_received",
                                        "message": "🎯 Analyzing Rajasthan policy impact..."
                                    })

                                    # Initialize and process policy simulation
                                    policy_simulator = PolicySimulatorChatInterface()
                                    policy_result = policy_simulator.process_policy_query(user_message)

                                    # Format policy response - use the same format as the working simple backend
                                    if policy_result.get("type") == "policy_simulation":
                                        # Serialize the response for JSON
                                        serialized_response = serialize_for_json(policy_result)

                                        # Send policy simulation response
                                        await websocket.send_json({
                                            "type": "policy_simulation",
                                            "data": serialized_response
                                        })
                                        logger.info("📤 Policy simulation response sent to client")
                                    else:
                                        # Handle other policy responses (errors, help, etc.)
                                        await websocket.send_json({
                                            "type": "text_response",
                                            "message": policy_result.get('message', 'Policy analysis completed')
                                        })

                                    continue

                                except Exception as e:
                                    logger.error(f"Policy simulation failed: {str(e)}")
                                    await websocket.send_json({
                                        "type": "error_message",
                                        "message": "Sorry, policy analysis failed. Using document search instead."
                                    })
                                    # Fall through to regular document search

                            # Regular document search (fallback)
                            # Send acknowledgment
                            await websocket.send_json({
                                "type": "message_received",
                                "message": "🔍 Searching Rajasthan government documents..."
                            })

                            # Search for relevant documents
                            search_results, source = search_documents_simple(user_message)
                            logger.info(f"📚 Found {len(search_results)} documents from {source}")

                            # Get LLM response
                            llm_response = await get_llm_response(user_message, search_results)

                            # Send response
                            await websocket.send_json({
                                "type": "text_response",
                                "message": llm_response
                            })

                        elif isinstance(data, dict) and data.get("type") == "user_info":
                            user_name = data.get("user_name", "Unknown")
                            logger.info(f"👤 User connected: {user_name}")

                        elif isinstance(data, dict) and data.get("lang"):
                            new_language = data.get("lang", "english")
                            # Avoid logging if the language hasn't changed
                            if new_language != user_language:
                                user_language = new_language
                                logger.info(f"🌐 Language preference updated: {user_language}")

                    elif "bytes" in message:
                        # Handle binary message (audio data)
                        audio_data = message["bytes"]
                        logger.info(f"🎤 Received audio data: {len(audio_data)} bytes")

                        if VOICE_AVAILABLE and voice_service:
                            try:
                                # Save audio data to a temporary file for processing
                                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                                    temp_file.write(audio_data)
                                    temp_file_path = temp_file.name

                                # Process audio with the voice service using the user's language preference
                                text = await voice_service.speech_to_text(temp_file_path, user_language)

                                # Clean up temp file
                                os.unlink(temp_file_path)

                                if text and text.strip():
                                    logger.info(f"🎤 Transcribed: {text}")

                                    # Search documents
                                    search_results, source = search_documents_simple(text)
                                    logger.info(f"📚 Found {len(search_results)} documents from {source}")

                                    # Get LLM response
                                    llm_response = await get_llm_response(text, search_results)

                                    # Send text response
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": llm_response
                                    })

                                    # Try to send voice response
                                    try:
                                        audio_response = await voice_service.text_to_speech(llm_response)
                                        if audio_response:
                                            await websocket.send_bytes(audio_response)
                                    except Exception as tts_error:
                                        logger.warning(f"TTS failed: {tts_error}")
                                else:
                                    await websocket.send_json({
                                        "type": "text_response",
                                        "message": "I couldn't understand what you said. Please try speaking more clearly or use text chat."
                                    })

                            except Exception as voice_error:
                                logger.error(f"Voice processing error: {voice_error}")
                                await websocket.send_json({
                                    "type": "text_response",
                                    "message": "Sorry, I couldn't process your voice input. Please try speaking again or use text chat."
                                })
                        else:
                            # Voice services not available
                            await websocket.send_json({
                                "type": "text_response",
                                "message": "Voice processing is currently unavailable. Please use the text chat to ask about Rajasthan pension schemes and government policies."
                            })

                elif message["type"] == "websocket.disconnect":
                    break

            except json.JSONDecodeError as e:
                logger.warning(f"⚠️ JSON decode error: {e}")
                continue
            except KeyError as e:
                logger.warning(f"⚠️ Missing key in message: {e}")
                continue

    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        db = lancedb.connect('./lancedb_data')
        tables = db.table_names()
        return {
            "status": "healthy",
            "tables": tables,
            "voice_available": VOICE_AVAILABLE
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}

if __name__ == "__main__":
    print("🚀 Starting voice-enabled WebSocket server on port 8000...")
    uvicorn.run(app, host="0.0.0.0", port=8000)
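A minimal client sketch for exercising the /ws/stream endpoint above, assuming the server is running locally on port 8000; the query text is a placeholder. It reads the greeting, sends one text_message, then prints the search acknowledgment and the answer:

import asyncio
import json
import websockets

async def chat_once():
    async with websockets.connect("ws://localhost:8000/ws/stream") as ws:
        print(json.loads(await ws.recv())["message"])   # connection_successful greeting
        await ws.send(json.dumps({
            "type": "text_message",
            "message": "old age pension eligibility in Rajasthan"
        }))
        print(json.loads(await ws.recv())["message"])   # message_received acknowledgment
        print(json.loads(await ws.recv())["message"])   # text_response answer

asyncio.run(chat_once())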
websocket_handler.py
ADDED
@@ -0,0 +1,403 @@
from fastapi import WebSocket, WebSocketDisconnect
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
import logging
import json
import asyncio
import re
import uuid
from typing import Dict, Any
from hybrid_llm_service import HybridLLMService  # Fixed import
from voice_service import VoiceService
from rag_service import search_documents
from llm_service import create_graph, create_basic_graph
from lancedb_service import lancedb_service
from policy_chat_interface import PolicySimulatorChatInterface

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize services
hybrid_llm_service = HybridLLMService()  # Create instance
voice_service = VoiceService()
policy_simulator = PolicySimulatorChatInterface()

# Policy simulation detection patterns
POLICY_PATTERNS = [
    r"scenario.*analy",
    r"policy.*simulat",
    r"pension.*analy",
    r"simulate.*dr|dr.*simulat",
    r"simulate.*pension|pension.*simulat",
    r"impact.*analy",
    r"dearness.*relief",
    r"basic.*pension",
    r"medical.*allowance",
    r"chart.*pension|pension.*chart",
    r"visual.*analy|analy.*visual",
    r"show.*chart|chart.*show",
    r"explain.*chart|chart.*explain",
    r"using.*chart|chart.*using",
    r"dr.*\d+.*increase|increase.*dr.*\d+",
    r"analyze.*minimum.*pension",
    r"pension.*change"
]

def is_policy_simulation_query(message: str) -> bool:
    """Check if the message is a policy simulation query"""
    message_lower = message.lower()
    return any(re.search(pattern, message_lower, re.IGNORECASE) for pattern in POLICY_PATTERNS)

async def handle_websocket_connection(websocket: WebSocket):
    """Handle WebSocket connection for the voice bot"""
    await websocket.accept()
    logger.info("🔌 WebSocket client connected.")

    initial_data = await websocket.receive_json()
    messages = []

    # Check if user authentication is provided
    flag = "user_id" in initial_data
    if flag:
        thread_id = initial_data.get("user_id")
        knowledge_base = initial_data.get("knowledge_base", "government_docs")

        # Create graph with RAG capabilities
        graph = await create_graph(kb_tool=True, mcp_config=None)

        config = {
            "configurable": {
                "thread_id": thread_id,
                "knowledge_base": knowledge_base,
            }
        }

        # Set system prompt for government document queries
        system_message = """You are a helpful assistant that can answer questions about government documents, policies, and procedures.
Keep your responses clear and concise. When referencing specific documents or policies, mention the source.
If you're uncertain about information, clearly state that and suggest where the user might find authoritative information."""

        messages.append(SystemMessage(content=system_message))
    else:
        # Basic graph for unauthenticated users
        graph = create_basic_graph()
        thread_id = str(uuid.uuid4())
        config = {"configurable": {"thread_id": thread_id}}

    # Send initial greeting
    greeting_message = HumanMessage(
        content="Generate a brief greeting for the user, introduce yourself as a government document assistant, and explain how you can help them find information from government policies and documents."
    )
    messages.append(greeting_message)

    try:
        response = await graph.ainvoke({"messages": messages}, config=config)
        greeting_response = response["messages"][-1].content
        messages.append(AIMessage(content=greeting_response))

        await websocket.send_json({
            "type": "connection_successful",
            "message": greeting_response
        })
    except Exception as e:
        logger.error(f"❌ Error generating greeting: {e}")
        await websocket.send_json({
            "type": "connection_successful",
            "message": "Hello! I'm your government document assistant. How can I help you today?"
        })

    try:
        while True:
            data = await websocket.receive_json()

            if data["type"] == "text_message":
                # Handle text message
                user_message = data["message"]
                logger.info(f"💬 Received text message: {user_message}")
                messages.append(HumanMessage(content=user_message))

                # Send acknowledgment
                await websocket.send_json({
                    "type": "message_received",
                    "message": "Processing your message..."
                })

                # Check if this is a policy simulation query
                if is_policy_simulation_query(user_message):
                    logger.info("🎯 Detected policy simulation query")
                    try:
                        # Process with policy simulator
                        policy_response = policy_simulator.process_policy_query(user_message)

                        # Send policy simulation response
                        await websocket.send_json({
                            "type": "policy_simulation",
                            "data": policy_response
                        })

                        messages.append(AIMessage(content=policy_response.get('message', 'Policy simulation completed')))
                        continue

                    except Exception as policy_error:
                        logger.error(f"❌ Policy simulation failed: {policy_error}")
                        # Fall back to normal processing

                # First try to search for relevant documents
                search_results = None
                try:
                    # Search for documents related to the user's query
                    search_results = search_documents(user_message, limit=5)
                    logger.info(f"📚 Found {len(search_results) if search_results else 0} documents for query")
                except Exception as search_error:
                    logger.warning(f"⚠️ Document search failed: {search_error}")

                # Get LLM response (with or without search context)
                try:
                    if search_results and len(search_results) > 0:
                        # Add search context to the message
                        context_message = f"User query: {user_message}\n\nRelevant documents found:\n"
                        for i, doc in enumerate(search_results[:3], 1):
                            context_message += f"\n{i}. Source: {doc.get('filename', 'Unknown')}\nContent: {doc.get('content', '')[:400]}...\n"

                        context_message += f"\nBased on the above documents, please provide a helpful response to the user's query: {user_message}"

                        # Replace the user message with the enriched version
                        messages[-1] = HumanMessage(content=context_message)

                    result = await graph.ainvoke({"messages": messages}, config=config)
                    llm_response = result["messages"][-1].content

                    # Check if the response contains scenario analysis images
                    if "**SCENARIO_IMAGES_START**" in llm_response and "**SCENARIO_IMAGES_END**" in llm_response:
                        # Extract images and text separately
                        parts = llm_response.split("**SCENARIO_IMAGES_START**")
                        text_response = parts[0].strip()

                        image_part = parts[1].split("**SCENARIO_IMAGES_END**")[0].strip()

                        try:
                            images = json.loads(image_part)

                            # Send text response first
                            await websocket.send_json({
                                "type": "text_response",
                                "message": text_response
                            })

                            # Send images separately
                            await websocket.send_json({
                                "type": "scenario_images",
                                "images": images
                            })

                        except json.JSONDecodeError:
                            # If JSON parsing fails, send as regular text
                            await websocket.send_json({
                                "type": "text_response",
                                "message": llm_response
                            })
                    else:
                        # Send regular text response
                        await websocket.send_json({
                            "type": "text_response",
                            "message": llm_response
                        })

                    # Add AI response to messages
                    messages.append(AIMessage(content=llm_response))

                    logger.info(f"✅ Sent response to user: {thread_id}")

                except Exception as e:
                    logger.error(f"❌ Error processing message: {e}")
                    await websocket.send_json({
                        "type": "error",
                        "message": "Sorry, I encountered an error processing your message."
                    })

            elif data["type"] == "ping":
                # Handle ping for connection keep-alive
                await websocket.send_json({"type": "pong"})

            elif data["type"] == "get_knowledge_bases":
                # Send available knowledge bases
                try:
                    kb_list = await lancedb_service.get_knowledge_bases()
                    await websocket.send_json({
                        "type": "knowledge_bases",
                        "knowledge_bases": kb_list
                    })
                except Exception as e:
                    logger.error(f"❌ Error getting knowledge bases: {e}")
                    await websocket.send_json({
                        "type": "error",
                        "message": "Error retrieving knowledge bases"
                    })

            elif data["type"] == "end_session":
                logger.info("🔌 Session ended by client")
                await websocket.close()
                break

    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected.")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Connection error occurred"
            })
        except Exception:
            pass
    finally:
        # Clean up when session ends
        logger.info(f"🔌 Session {thread_id} ended")

async def send_welcome_message(websocket: WebSocket):
    """Send welcome message to the client"""
    try:
        welcome_text = """🇮🇳 Welcome to the Government Services AI Assistant!

I'm here to help you with:
• Government policies and procedures
• Document information and guidance
• Service-specific questions and redirects
• Voice or text interaction (your choice!)

How can I assist you today?"""

        await websocket.send_text(json.dumps({
            "type": "bot_message",
            "content": welcome_text,
            "timestamp": asyncio.get_event_loop().time()
        }))

    except Exception as e:
        logger.error(f"❌ Error sending welcome message: {e}")

async def handle_text_message(websocket: WebSocket, message_data: Dict[str, Any]):
    """Handle text-based messages"""
    try:
        user_message = message_data.get("content", "")
        logger.info(f"💬 Processing text message: {user_message}")

        # Search for relevant documents
        context = ""
        try:
            search_results = search_documents(user_message, limit=3)
            if search_results:
                context = "\n".join([doc.get("content", "") for doc in search_results])
                logger.info(f"📚 Found {len(search_results)} relevant documents")
        except Exception as e:
            logger.warning(f"⚠️ Document search failed: {e}")

        # Get response from hybrid LLM
        response_text = ""
        try:
            # Check if this is a streaming request
            stream_response = message_data.get("stream", True)

            if stream_response:
                # Send streaming response
                await websocket.send_text(json.dumps({
                    "type": "bot_message_start",
                    "timestamp": asyncio.get_event_loop().time()
                }))

                async for chunk in hybrid_llm_service.get_streaming_response(user_message, context):
                    response_text += chunk
                    await websocket.send_text(json.dumps({
                        "type": "bot_message_chunk",
                        "content": chunk,
                        "timestamp": asyncio.get_event_loop().time()
                    }))
                    await asyncio.sleep(0.01)  # Small delay for better streaming

                await websocket.send_text(json.dumps({
                    "type": "bot_message_end",
                    "timestamp": asyncio.get_event_loop().time()
                }))
            else:
                # Send complete response
                response_text = await hybrid_llm_service.get_response(user_message, context)
                await websocket.send_text(json.dumps({
                    "type": "bot_message",
                    "content": response_text,
                    "timestamp": asyncio.get_event_loop().time()
                }))

        except Exception as e:
            logger.error(f"❌ Error getting LLM response: {e}")
            await websocket.send_text(json.dumps({
                "type": "bot_message",
                "content": f"I apologize, but I encountered an error processing your request: {str(e)}",
                "timestamp": asyncio.get_event_loop().time()
            }))

        # Add government service redirect suggestions
        try:
            redirect_suggestions = voice_service.generate_redirect_suggestions(user_message, "text")
            if redirect_suggestions:
                await websocket.send_text(json.dumps({
                    "type": "redirect_suggestions",
                    "content": redirect_suggestions,
                    "timestamp": asyncio.get_event_loop().time()
                }))
        except Exception as e:
            logger.warning(f"⚠️ Could not generate redirect suggestions: {e}")

    except Exception as e:
        logger.error(f"❌ Error handling text message: {e}")
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error processing your message: {str(e)}"
        }))

async def handle_voice_message(websocket: WebSocket, message_data: Dict[str, Any]):
    """Handle voice-based messages"""
    try:
        # Check if voice features are enabled
        if not voice_service.voice_enabled:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "Voice features are currently disabled. Please use text input."
            }))
            return

        audio_data = message_data.get("audio_data", "")
        if not audio_data:
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": "No audio data received"
            }))
            return

        logger.info("🎤 Processing voice message")

        # Convert speech to text
        try:
            # NOTE: VoiceService.speech_to_text expects a path to an audio file on disk
            transcribed_text = await voice_service.speech_to_text(audio_data)
            logger.info(f"📝 Transcribed: {transcribed_text}")

            # Send transcription to client
            await websocket.send_text(json.dumps({
                "type": "transcription",
                "content": transcribed_text,
                "timestamp": asyncio.get_event_loop().time()
            }))

        except Exception as e:
            logger.error(f"❌ Speech-to-text failed: {e}")
            await websocket.send_text(json.dumps({
                "type": "error",
                "content": f"Speech recognition failed: {str(e)}"
            }))
    except Exception as e:
        logger.error(f"❌ Error handling voice message: {e}")
        await websocket.send_text(json.dumps({
            "type": "error",
            "content": f"Error processing voice message: {str(e)}"
        }))
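For completeness, a wiring sketch showing how this handler could be mounted; the /ws route path is an assumption, since the actual mounting lives in app.py (not shown in this hunk):

from fastapi import FastAPI, WebSocket
from websocket_handler import handle_websocket_connection

app = FastAPI()

@app.websocket("/ws")  # route path is illustrative
async def ws_route(websocket: WebSocket):
    # handle_websocket_connection performs accept(), the receive loop, and cleanup
    await handle_websocket_connection(websocket)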