| """ | |
| Enhanced WebSocket handler with hybrid LLM and voice features | |
| """ | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from langchain_core.messages import HumanMessage, SystemMessage, AIMessage | |
| import logging | |
| import json | |
| import asyncio | |
| import uuid | |
| import tempfile | |
| import base64 | |
| from pathlib import Path | |
| import io | |
| import matplotlib.pyplot as plt | |
| from llm_service import create_graph, create_basic_graph | |
| from lancedb_service import lancedb_service | |
| from hybrid_llm_service import HybridLLMService | |
| from voice_service import voice_service | |
| from rag_service import search_government_docs | |
| from policy_chart_generator import PolicyChartGenerator | |
| from conversational_service import conversational_service | |
| # Initialize hybrid LLM service | |
| hybrid_llm_service = HybridLLMService() | |
| logger = logging.getLogger("voicebot") | |
def analyze_query_context(query: str) -> dict:
    """Analyze query to determine if it's document-related or general, and identify user role"""
    query_lower = query.lower()

    # Role-specific keywords and queries
    role_patterns = {
        'pension_beneficiary': [
            'pension eligibility', 'pension documents', 'pension application', 'retirement benefits',
            'pension calculation', 'pension amount', 'family pension', 'commutation',
            'gratuity eligibility', 'provident fund withdrawal', 'medical benefits after retirement',
            'pension certificate', 'life certificate', 'pension arrears', 'how to apply pension',
            'pension office', 'pension disbursement', 'pension inquiry', 'pension status'
        ],
        'procurement_officer': [
            'tender process', 'bid submission', 'procurement thresholds', 'gem portal',
            'msme relaxation', 'vendor registration', 'procurement checklist', 'bid evaluation',
            'tender documents', 'procurement rules', 'bidding process', 'contract award',
            'procurement guidelines', 'tender notice', 'technical bid', 'financial bid',
            'procurement manual', 'vendor empanelment', 'tender committee'
        ],
        'finance_staff': [
            'sanctioning authority', 'financial approval', 'budget allocation', 'expenditure sanction',
            'financial registers', 'audit compliance', 'treasury rules', 'payment authorization',
            'financial delegation', 'budget utilization', 'fund release', 'financial procedure',
            'accounting rules', 'financial reporting', 'expenditure control', 'financial audit',
            'cash book', 'voucher processing', 'financial clearance'
        ],
        'leadership_policymaker': [
            'policy impact', 'scenario analysis', 'cost comparison', 'policy implementation',
            'evidence pack', 'policy evaluation', 'impact assessment', 'strategic planning',
            'policy formulation', 'comparative analysis', 'policy review', 'governance framework',
            'administrative reform', 'policy effectiveness', 'decision support', 'policy brief'
        ]
    }

    # Government document keywords (expanded)
    doc_keywords = [
        'pension', 'leave', 'allowance', 'da', 'dearness', 'procurement', 'tender',
        'medical', 'reimbursement', 'transfer', 'posting', 'promotion', 'service',
        'rules', 'policy', 'government', 'circular', 'notification', 'benefits',
        'gratuity', 'provident fund', 'retirement', 'salary', 'pay commission',
        'eligibility', 'documents', 'application', 'process', 'approval', 'sanction',
        'audit', 'finance', 'budget', 'expenditure', 'treasury', 'guidelines'
    ]

    # General conversation keywords
    general_keywords = [
        'hello', 'hi', 'thank you', 'thanks', 'goodbye', 'bye', 'help',
        'how are you', 'what is your name', 'who are you', 'weather',
        'time', 'date', 'joke', 'story', 'song', 'recipe', 'movie'
    ]

    # Detect user role
    detected_role = None
    role_confidence = 0.0
    for role, patterns in role_patterns.items():
        role_matches = sum(1 for pattern in patterns if pattern in query_lower)
        if role_matches > 0:
            current_confidence = min(role_matches * 0.4, 1.0)
            if current_confidence > role_confidence:
                detected_role = role
                role_confidence = current_confidence

    # Count keyword matches
    doc_matches = sum(1 for kw in doc_keywords if kw in query_lower)
    general_matches = sum(1 for kw in general_keywords if kw in query_lower)

    # Determine query type; lean toward document searches for this system
    if doc_matches > 0 or detected_role:
        query_type = "document_related"
        confidence = max(min(doc_matches * 0.3, 1.0), role_confidence)
    elif general_matches > 0 and doc_matches == 0:
        # Only treat as general if there are ZERO document keywords
        query_type = "general_conversation"
        confidence = min(general_matches * 0.4, 1.0)
    else:
        # Default to document search - this is a government document system
        query_type = "document_related"
        confidence = 0.5  # Higher confidence for document search by default

    return {
        "type": query_type,
        "confidence": confidence,
        "doc_keywords_found": doc_matches,
        "general_keywords_found": general_matches,
        "detected_role": detected_role,
        "role_confidence": role_confidence
    }

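
# Illustrative call (note that keyword hits are substring matches against the
# lowered query, so short keys such as 'da' can also fire inside longer words):
#
#   analyze_query_context("What documents do I need for my pension application?")
#   # -> {'type': 'document_related', 'confidence': ~0.9,
#   #     'doc_keywords_found': 3, 'general_keywords_found': 0,
#   #     'detected_role': 'pension_beneficiary', 'role_confidence': 0.4}
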
async def generate_llm_fallback_response(user_message: str, query_context: dict) -> str:
    """Generate response using Groq/Gemini for out-of-context queries"""
    try:
        # Determine which LLM to use based on query complexity
        provider = hybrid_llm_service.choose_llm_provider(user_message)

        # Create role-aware system prompt
        detected_role = query_context.get("detected_role")

        if query_context.get("type") == "general_conversation":
            system_prompt = """You are a helpful assistant for a government document system.
            The user is asking a general question not related to government documents.
            Provide a friendly, helpful response and gently guide them to ask about government policies,
            pension rules, leave policies, procurement procedures, or other administrative matters if they need official information."""
        elif detected_role == "pension_beneficiary":
            system_prompt = """You are an AI assistant specializing in government pension and retirement benefits.
            The user appears to be a pension beneficiary or claimant. Provide helpful information about pension eligibility,
            application processes, required documents, and procedures. Always remind them to verify information with
            the pension disbursing authority and consult official government sources for the most current rules."""
        elif detected_role == "procurement_officer":
            system_prompt = """You are an AI assistant specializing in government procurement procedures.
            The user appears to be involved in procurement or bidding processes. Provide helpful information about
            tender procedures, MSME benefits, GeM portal usage, and procurement guidelines. Always remind them to
            follow current procurement rules and consult the latest government circulars."""
        elif detected_role == "finance_staff":
            system_prompt = """You are an AI assistant specializing in government financial procedures.
            The user appears to be finance staff. Provide helpful information about sanctioning procedures,
            budget management, audit compliance, and treasury rules. Always remind them to follow current
            financial rules and consult with the accounts department for official procedures."""
        elif detected_role == "leadership_policymaker":
            system_prompt = """You are an AI assistant specializing in policy analysis and decision support.
            The user appears to be in a leadership or policy-making role. Provide helpful information about
            policy impact analysis, evidence-based decision making, and strategic planning. Always recommend
            consulting with relevant departments and conducting proper stakeholder consultations."""
        else:
            system_prompt = """You are an AI assistant for government document queries.
            The user asked about something that wasn't found in the document database.
            Provide helpful general information if you can, but always remind them that for official
            government policies and procedures, they should consult official sources or contact
            the relevant government office. Keep responses concise and professional."""

        # Generate response using hybrid LLM service
        if provider:
            response = await hybrid_llm_service.generate_response(
                user_message,
                system_prompt=system_prompt,
                provider=provider
            )
            logger.info(f"✅ Generated LLM fallback response using {provider.value}")
            return response
        else:
            logger.warning("⚠️ No LLM provider available")
            return "I understand your question, but I'm currently unable to access my AI capabilities. Please try again later or contact the relevant government office for official information."

    except Exception as e:
        logger.error(f"❌ Error generating LLM fallback response: {e}")
        return f"I apologize, but I encountered an error while processing your query: '{user_message}'. Please try rephrasing your question or contact the relevant authorities for assistance."

def validate_transcription_quality(text: str, language: str) -> dict:
    """Validate transcription quality and provide suggestions"""
    if not text or not text.strip():
        return {
            "score": 0.0,
            "level": "very_low",
            "suggestions": ["No speech detected", "Check microphone", "Speak closer to microphone"]
        }

    text_clean = text.strip()
    words = text_clean.split()

    # Quality indicators
    word_count = len(words)
    avg_word_length = sum(len(word) for word in words) / max(word_count, 1)
    has_meaningful_words = any(len(word) > 2 for word in words)

    # Check for garbled/nonsensical words (too many consonants, unusual patterns)
    garbled_words = 0
    for word in words:
        word_clean = ''.join(c for c in word.lower() if c.isalpha())
        if len(word_clean) > 3:
            consonants = sum(1 for c in word_clean if c not in 'aeiou')
            vowels = len(word_clean) - consonants
            if consonants > vowels * 2:  # Too many consonants
                garbled_words += 1
    garbled_ratio = garbled_words / max(word_count, 1)

    # Language-specific checks
    if language in ['en', 'hi-en']:
        # Check for common English/Hinglish patterns
        common_words = ['the', 'and', 'is', 'in', 'to', 'of', 'for', 'with', 'on', 'at', 'by', 'from',
                        'pension', 'government', 'policy', 'rules', 'what', 'how', 'why', 'when', 'where',
                        'benefits', 'allowance', 'service', 'employee', 'officer', 'department']
        has_common_words = any(word.lower() in common_words for word in words)

        # Check for obvious nonsensical combinations
        nonsensical_patterns = ['benchern', 'trend rules', 'rinterpret', 'wht']
        has_nonsensical = any(pattern in text_clean.lower() for pattern in nonsensical_patterns)
    else:
        has_common_words = True  # Assume valid for other languages
        has_nonsensical = False

    # Calculate quality score
    score = 0.0
    if word_count > 0:
        score += 0.2
    if word_count >= 3:
        score += 0.2
    if avg_word_length > 2:
        score += 0.2
    if has_meaningful_words:
        score += 0.2
    if has_common_words:
        score += 0.2

    # Apply penalties
    if garbled_ratio > 0.3:  # More than 30% garbled words
        score *= 0.3
    elif garbled_ratio > 0.1:  # More than 10% garbled words
        score *= 0.6
    if has_nonsensical:
        score *= 0.2
    if word_count < 2 or avg_word_length < 2:
        score *= 0.5

    # Determine quality level and suggestions
    if score >= 0.7:
        level = "high"
        suggestions = []
    elif score >= 0.4:
        level = "medium"
        suggestions = ["Speak a bit more clearly for better recognition"]
    elif score >= 0.2:
        level = "low"
        suggestions = ["Speak more clearly", "Try speaking slower", "Reduce background noise"]
    else:
        level = "very_low"
        suggestions = ["Audio quality is poor", "Speak closer to microphone", "Reduce background noise", "Try speaking more slowly and clearly"]

    return {
        "score": score,
        "level": level,
        "suggestions": suggestions,
        "garbled_ratio": garbled_ratio,
        "word_count": word_count
    }

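
# Illustrative trace of the scoring heuristics above:
#
#   validate_transcription_quality("what is the pension policy", "en")
#   # -> score 0.6, level "medium": all five 0.2 bonuses apply, but the
#   #    consonant/vowel heuristic counts "what" (3 consonants, 1 vowel) as
#   #    garbled, so garbled_ratio is 0.2 and the 0.6x penalty kicks in.
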
def create_language_context(user_language: str, normalized_language: str) -> str:
    """Create appropriate language context for LLM responses"""
    if not user_language:
        return ""

    lang_lower = user_language.lower()
    if lang_lower in ['hindi', 'hi', 'hi-in']:
        return " (User is speaking in Hindi. You may include relevant Hindi terms for government policies in India, especially for technical terms like 'सरकारी नीति', 'पेंशन', 'भत्ता' etc.)"
    elif lang_lower in ['hinglish', 'hi-en']:
        return " (User is speaking in Hinglish - Hindi-English mix. Feel free to use both languages naturally in your response, especially for government terminology.)"
    elif lang_lower in ['spanish', 'es']:
        return " (User is speaking in Spanish. Respond in Spanish if possible, or provide translations for key terms.)"
    elif lang_lower in ['french', 'fr']:
        return " (User is speaking in French. Respond in French if possible, or provide translations for key terms.)"
    elif lang_lower in ['arabic', 'ar']:
        return " (User is speaking in Arabic. Respond in Arabic if possible, or provide translations for key terms.)"
    elif lang_lower in ['chinese', 'zh']:
        return " (User is speaking in Chinese. Respond in Chinese if possible, or provide translations for key terms.)"
    elif lang_lower in ['japanese', 'ja']:
        return " (User is speaking in Japanese. Respond in Japanese if possible, or provide translations for key terms.)"
    elif lang_lower in ['english', 'en', 'en-us', 'en-in']:
        return " (User is speaking in English. Provide clear, professional responses.)"
    else:
        return f" (User language preference: {user_language}. Adapt response accordingly if possible.)"

def select_voice_for_language(user_language: str, preferred_voice: str = None) -> str:
    """Select appropriate TTS voice based on user's language"""
    if preferred_voice:
        return preferred_voice
    if not user_language:
        return "en-US-AriaNeural"  # Default

    lang_lower = user_language.lower()

    # Voice mapping for different languages
    voice_map = {
        'hindi': 'hi-IN-SwaraNeural',
        'hi': 'hi-IN-SwaraNeural',
        'hi-in': 'hi-IN-SwaraNeural',
        'hinglish': 'en-IN-NeerjaNeural',  # Indian English for Hinglish
        'hi-en': 'en-IN-NeerjaNeural',
        'english': 'en-US-AriaNeural',
        'en': 'en-US-AriaNeural',
        'en-us': 'en-US-AriaNeural',
        'en-in': 'en-IN-NeerjaNeural',
        'spanish': 'es-ES-ElviraNeural',
        'es': 'es-ES-ElviraNeural',
        'french': 'fr-FR-DeniseNeural',
        'fr': 'fr-FR-DeniseNeural',
        'german': 'de-DE-KatjaNeural',
        'de': 'de-DE-KatjaNeural',
        'portuguese': 'pt-BR-FranciscaNeural',
        'pt': 'pt-BR-FranciscaNeural',
        'italian': 'it-IT-ElsaNeural',
        'it': 'it-IT-ElsaNeural',
        'russian': 'ru-RU-SvetlanaNeural',
        'ru': 'ru-RU-SvetlanaNeural',
        'chinese': 'zh-CN-XiaoxiaoNeural',
        'zh': 'zh-CN-XiaoxiaoNeural',
        'japanese': 'ja-JP-NanamiNeural',
        'ja': 'ja-JP-NanamiNeural',
        'arabic': 'ar-SA-ZariyahNeural',
        'ar': 'ar-SA-ZariyahNeural'
    }
    return voice_map.get(lang_lower, 'en-US-AriaNeural')

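
# Example lookups (an explicit preferred_voice always wins; unknown languages
# fall back to the en-US default):
#
#   select_voice_for_language("hindi")                       # 'hi-IN-SwaraNeural'
#   select_voice_for_language("hindi", "en-GB-SoniaNeural")  # 'en-GB-SoniaNeural'
#   select_voice_for_language("klingon")                     # 'en-US-AriaNeural'
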
def attempt_transcription_correction(text: str, quality_info: dict) -> str:
    """Attempt to correct common transcription errors, especially for government terms"""
    if not text or quality_info.get('score', 1) > 0.6:
        return text  # Don't correct if quality is already good

    corrected = text

    # Common government term corrections. Ambiguous homophones (e.g. 'no',
    # 'were') are not mapped, since they are usually valid words.
    corrections = {
        # Pension-related corrections
        'tension': 'pension',
        'penshun': 'pension',
        'penshan': 'pension',
        'mention': 'pension',
        'bruised': 'rules',
        'bruce': 'rules',
        'brews': 'rules',
        'cruise': 'rules',
        # Policy-related corrections
        'polity': 'policy',
        'polly': 'policy',
        # Government-related corrections
        'goverment': 'government',
        'govermint': 'government',
        # Allowance corrections
        'allowens': 'allowance',
        'alowance': 'allowance',
        # Benefits corrections
        'benifits': 'benefits',
        'benefets': 'benefits',
        # Common words
        'wat': 'what',
        'wot': 'what',
        'wen': 'when',
        'haw': 'how',
        'noe': 'know'
    }

    # Split into words and correct each
    words = corrected.split()
    corrected_words = []

    for word in words:
        # Remove punctuation for matching
        clean_word = word.lower().strip('.,!?;:')

        # Check for corrections
        if clean_word in corrections and corrections[clean_word] != clean_word:
            # Preserve original capitalization pattern
            if word.isupper():
                corrected_word = corrections[clean_word].upper()
            elif word.istitle():
                corrected_word = corrections[clean_word].capitalize()
            else:
                corrected_word = corrections[clean_word]

            # Preserve punctuation
            punctuation = word[len(clean_word):] if len(word) > len(clean_word) else ''
            corrected_words.append(corrected_word + punctuation)
        else:
            corrected_words.append(word)

    final_corrected = ' '.join(corrected_words)

    # Only return correction if it's significantly different
    if final_corrected.lower() != text.lower():
        return final_corrected
    return text

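
# Illustrative correction (only attempted when the quality score is <= 0.6):
#
#   attempt_transcription_correction("tension bruised for govermint", {"score": 0.3})
#   # -> "pension rules for government"
#   attempt_transcription_correction("tension bruised", {"score": 0.8})
#   # -> "tension bruised" (returned unchanged; quality already good)
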
async def handle_enhanced_websocket_connection(websocket: WebSocket):
    """Enhanced WebSocket handler with hybrid LLM and voice features"""
    await websocket.accept()
    logger.info("🔌 Enhanced WebSocket client connected.")

    # Initialize session data
    session_data = {
        "messages": [],
        "user_preferences": {
            "voice_enabled": True,  # Enable voice by default since this is a voice bot
            "preferred_voice": "en-US-AriaNeural",
            "response_mode": "both"  # text, voice, both - default to both for voice bot
        },
        "context": ""
    }

    try:
        # Get initial connection data
        initial_data = await websocket.receive_json()

        # Validate initial data
        if not isinstance(initial_data, dict):
            logger.warning(f"⚠️ Invalid initial data format: {type(initial_data)}")
            initial_data = {}

        logger.info(f"📨 Initial connection data: {initial_data}")

        # Extract user preferences
        if "preferences" in initial_data:
            session_data["user_preferences"].update(initial_data["preferences"])

        # Setup user session
        flag = "user_id" in initial_data
        graph = None  # Initialize graph variable

        if flag:
            thread_id = initial_data.get("user_id")
            knowledge_base = initial_data.get("knowledge_base", "government_docs")

            # Use hybrid LLM or traditional graph based on configuration
            if hybrid_llm_service.use_hybrid:
                logger.info("🤖 Using Hybrid LLM Service")
                use_hybrid = True
            else:
                graph = await create_graph(kb_tool=True, mcp_config=None)
                use_hybrid = False

            config = {
                "configurable": {
                    "thread_id": thread_id,
                    "knowledge_base": knowledge_base,
                }
            }
        else:
            # Basic setup for unauthenticated users
            thread_id = str(uuid.uuid4())
            knowledge_base = "government_docs"
            use_hybrid = hybrid_llm_service.use_hybrid
            if not use_hybrid:
                graph = create_basic_graph()
            config = {"configurable": {"thread_id": thread_id}}

        # Send initial greeting with voice/hybrid capabilities
        await send_enhanced_greeting(websocket, session_data)

        # Main message handling loop
        while True:
            try:
                data = await websocket.receive_json()

                # Validate message format
                if not isinstance(data, dict):
                    logger.warning(f"⚠️ Invalid message format: {type(data)}")
                    continue
                if "type" not in data:
                    logger.warning(f"⚠️ Message missing 'type' field: {data}")
                    continue

                message_type = data["type"]
                logger.debug(f"📨 Received message type: {message_type}")

                if message_type == "text_message":
                    await handle_text_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )
                elif message_type == "voice_message":
                    await handle_voice_message(
                        websocket, data, session_data,
                        use_hybrid, config, knowledge_base, graph
                    )
                elif message_type == "preferences_update":
                    await handle_preferences_update(websocket, data, session_data)
                elif message_type == "get_voice_status":
                    await websocket.send_json({
                        "type": "voice_status",
                        "data": voice_service.get_voice_status()
                    })
                elif message_type == "get_llm_status":
                    await websocket.send_json({
                        "type": "llm_status",
                        "data": hybrid_llm_service.get_provider_info()
                    })
                elif message_type == "connection":
                    # Handle initial connection - already processed above
                    logger.debug("📨 Connection message received (already processed)")
                elif message_type == "get_knowledge_bases":
                    # Handle knowledge base request
                    await websocket.send_json({
                        "type": "knowledge_bases",
                        "knowledge_bases": ["government_docs", "rajasthan_documents"]
                    })
                else:
                    logger.warning(f"⚠️ Unknown message type: {message_type}")

            except WebSocketDisconnect:
                logger.info("🔌 WebSocket client disconnected.")
                break
            except Exception as e:
                logger.error(f"❌ Error handling message: {e}")
                try:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"An error occurred: {str(e)}"
                    })
                except Exception:
                    pass  # Connection might be closed
    except WebSocketDisconnect:
        logger.info("🔌 WebSocket client disconnected during setup.")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}"
            })
        except Exception:
            pass

async def send_enhanced_greeting(websocket: WebSocket, session_data: dict):
    """Send enhanced greeting with system capabilities"""
    # Get system status
    llm_info = hybrid_llm_service.get_provider_info()
    voice_status = voice_service.get_voice_status()

    greeting_text = f"""🤖 Welcome to the Government Document Assistant!

I'm powered by a hybrid AI system that can help you with:
• Government policies and procedures
• Document search and analysis
• Scenario analysis with visualizations
• Quick answers and detailed explanations

Current capabilities:
• LLM: {'Hybrid (' + llm_info['fast_provider'] + '/' + llm_info['complex_provider'] + ')' if llm_info['hybrid_enabled'] else 'Single provider'}
• Voice features: {'Enabled' if voice_status['voice_enabled'] else 'Disabled'}

How can I assist you today? You can ask me about any government policies, procedures, or documents!"""

    # Send text greeting
    await websocket.send_json({
        "type": "connection_successful",
        "message": greeting_text,
        "provider_used": "system",
        "capabilities": {
            "hybrid_llm": llm_info['hybrid_enabled'],
            "voice_features": voice_status['voice_enabled'],
            "scenario_analysis": True
        }
    })

    # Send voice greeting if enabled
    if session_data["user_preferences"]["voice_enabled"] and voice_status['voice_enabled']:
        voice_greeting = "Welcome to the Government Document Assistant! I can help you with policies, procedures, and document analysis. How can I assist you today?"
        audio_data = await voice_service.text_to_speech(voice_greeting)
        if audio_data:
            await websocket.send_json({
                "type": "audio_response",
                "audio_data": base64.b64encode(audio_data).decode(),
                "format": "mp3"
            })

async def handle_text_message(websocket: WebSocket, data: dict, session_data: dict,
                              use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle text message with hybrid LLM"""
    user_message = data["message"]
    logger.info(f"💬 Received text message: {user_message}")

    # Send acknowledgment
    await websocket.send_json({
        "type": "message_received",
        "message": "Processing your message..."
    })

    try:
        if use_hybrid:
            # Stream hybrid LLM service response
            response_chunks = []
            provider_used = None
            async for chunk in get_hybrid_response(
                user_message, session_data["context"], config, knowledge_base, session_data.get("session_id")
            ):
                response_chunks.append(chunk)
                # Send each chunk as structured data
                await websocket.send_json({
                    "type": "streaming_response",
                    "clause_text": chunk.get("clause_text", ""),
                    "summary": chunk.get("summary", ""),
                    "role_checklist": chunk.get("role_checklist", []),
                    "source_title": chunk.get("source_title", ""),
                    "clause_id": chunk.get("clause_id", ""),
                    "date": chunk.get("date", ""),
                    "url": chunk.get("url", ""),
                    "score": chunk.get("score", 1.0),
                    "scenario_analysis": chunk.get("scenario_analysis", None),
                    "charts": chunk.get("charts", [])
                })

            # Optionally, aggregate or select the best chunk for final response.
            # Here, just use the first chunk for context update and provider.
            if response_chunks:
                provider_used = hybrid_llm_service.choose_llm_provider(user_message)
                provider_used = provider_used.value if provider_used else "unknown"
                session_data["context"] = response_chunks[0].get("clause_text", "")[-1000:]
        else:
            # Use traditional graph approach
            session_data["messages"].append(HumanMessage(content=user_message))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"
            await send_text_response(websocket, response_text, provider_used, session_data)

        await websocket.send_json({
            "type": "llm_response",
            "text": "Done",
            "provider_used": provider_used,
            "timestamp": asyncio.get_event_loop().time()
        })

    except Exception as e:
        logger.error(f"❌ Error processing text message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing your message: {str(e)}"
        })

async def handle_voice_message(websocket: WebSocket, data: dict, session_data: dict,
                               use_hybrid: bool, config: dict, knowledge_base: str, graph=None):
    """Handle voice message with enhanced multi-language ASR and TTS"""
    if not voice_service.is_voice_enabled():
        await websocket.send_json({
            "type": "error",
            "message": "Voice features are not enabled"
        })
        return

    try:
        # Get audio data - handle both old and new format
        if "audio_data" in data:
            audio_data = base64.b64decode(data["audio_data"])
        else:
            # Handle old format or direct binary data
            logger.error("❌ No audio_data field found in voice message")
            await websocket.send_json({
                "type": "error",
                "message": "No audio data provided"
            })
            return

        # Extract and validate user language preference
        user_language = data.get("lang") or data.get("language") or session_data.get("language") or session_data["user_preferences"].get("language") or "english"

        # Normalize language codes
        language_map = {
            'english': 'en', 'en': 'en', 'en-us': 'en', 'en-in': 'en',
            'hindi': 'hi', 'hi': 'hi', 'hi-in': 'hi',
            'hinglish': 'hi-en', 'hi-en': 'hi-en',
            'spanish': 'es', 'es': 'es',
            'french': 'fr', 'fr': 'fr',
            'german': 'de', 'de': 'de',
            'portuguese': 'pt', 'pt': 'pt',
            'italian': 'it', 'it': 'it',
            'russian': 'ru', 'ru': 'ru',
            'chinese': 'zh', 'zh': 'zh',
            'japanese': 'ja', 'ja': 'ja',
            'arabic': 'ar', 'ar': 'ar'
        }
        normalized_language = language_map.get(user_language.lower(), 'en')
        logger.info(f"🌐 Processing voice with language: {user_language} (normalized: {normalized_language})")

        # Save to temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        # Check if we should use server-side ASR or expect browser transcription
        if voice_service.asr_provider == "browser-native":
            # Expect transcription to come from browser, not from audio processing
            logger.info("🌐 Using browser-native ASR - expecting transcription from client")
            # Clean up temp file since we won't process it
            Path(temp_file_path).unlink()

            # Check if transcription was provided in the message
            if "transcription" in data:
                transcribed_text = data["transcription"]
                logger.info(f"🎤 Browser transcription ({user_language}): {transcribed_text}")
            else:
                await websocket.send_json({
                    "type": "info",
                    "message": "Browser ASR mode - please ensure your browser supports speech recognition"
                })
                return
        else:
            # Use server-side ASR (Whisper) with multiple attempts if needed
            logger.info(f"🎤 Processing audio with language preference: {user_language}")
            transcribed_text = await voice_service.speech_to_text(temp_file_path, normalized_language)

            # If transcription seems poor, try with English as fallback
            if transcribed_text and normalized_language != 'en':
                quality_check = validate_transcription_quality(transcribed_text, normalized_language)
                if quality_check['score'] < 0.3:
                    logger.info("🔄 Trying English transcription as fallback")
                    english_transcription = await voice_service.speech_to_text(temp_file_path, 'en')
                    if english_transcription:
                        english_quality = validate_transcription_quality(english_transcription, 'en')
                        if english_quality['score'] > quality_check['score'] + 0.2:
                            logger.info(f"🎯 English transcription better: {english_transcription}")
                            transcribed_text = english_transcription
                            normalized_language = 'en'

            # Clean up temp file
            Path(temp_file_path).unlink()

        if not transcribed_text:
            await websocket.send_json({
                "type": "error",
                "message": "Could not transcribe audio. Please try speaking clearly or check your microphone."
            })
            return

        # Validate and potentially correct transcription
        transcription_quality = validate_transcription_quality(transcribed_text, normalized_language)
        corrected_text = attempt_transcription_correction(transcribed_text, transcription_quality)

        # Use corrected text if available and quality improved
        final_text = corrected_text if corrected_text != transcribed_text else transcribed_text
        final_quality = validate_transcription_quality(final_text, normalized_language) if corrected_text != transcribed_text else transcription_quality

        logger.info(f"🎤 Transcribed ({user_language}): {transcribed_text} | Quality: {transcription_quality['score']:.2f}")
        if corrected_text != transcribed_text:
            logger.info(f"🔧 Corrected to: {final_text} | New Quality: {final_quality['score']:.2f}")

        # Send transcription with quality info
        await websocket.send_json({
            "type": "transcription",
            "text": final_text,
            "original_text": transcribed_text if corrected_text != transcribed_text else None,
            "language": user_language or "auto-detected",
            "confidence": final_quality['level'],
            "quality_score": final_quality['score'],
            "suggestions": final_quality['suggestions'],
            "was_corrected": corrected_text != transcribed_text
        })

        # Handle low-quality transcription with detailed feedback
        if final_quality['score'] < 0.2:
            await websocket.send_json({
                "type": "transcription_error",
                "message": f"Could not understand the audio clearly. Transcribed: '{final_text}'. Please try again with clearer speech.",
                "suggestions": final_quality['suggestions'],
                "quality_details": {
                    "score": final_quality['score'],
                    "garbled_ratio": final_quality.get('garbled_ratio', 0),
                    "word_count": final_quality.get('word_count', 0)
                }
            })
            return
        elif final_quality['score'] < 0.4:
            # Continue processing but warn user
            correction_note = f" (Auto-corrected from: '{transcribed_text}')" if corrected_text != transcribed_text else ""
            await websocket.send_json({
                "type": "transcription_warning",
                "message": f"Audio quality is low (Score: {final_quality['score']:.2f}). I heard: '{final_text}'{correction_note}. Is this correct?",
                "suggestions": final_quality['suggestions'] + ["Try speaking more slowly", "Ensure microphone is close to your mouth", "Reduce background noise"]
            })

        # Add comprehensive language context to the prompt for better responses
        language_context = create_language_context(user_language, normalized_language)
        enhanced_message = final_text + language_context

        # Process as text message with language context
        if use_hybrid:
            response_chunks = []
            async for chunk in get_hybrid_response(
                enhanced_message, session_data["context"], config, knowledge_base, session_data.get("session_id")
            ):
                response_chunks.append(chunk)
                # Send each chunk as structured data
                await websocket.send_json({
                    "type": "streaming_response",
                    "clause_text": chunk.get("clause_text", ""),
                    "summary": chunk.get("summary", ""),
                    "role_checklist": chunk.get("role_checklist", []),
                    "source_title": chunk.get("source_title", ""),
                    "clause_id": chunk.get("clause_id", ""),
                    "date": chunk.get("date", ""),
                    "url": chunk.get("url", ""),
                    "score": chunk.get("score", 1.0),
                    "scenario_analysis": chunk.get("scenario_analysis", None),
                    "charts": chunk.get("charts", [])
                })

            # Create response text for voice synthesis from the chunks
            response_text_parts = []
            for chunk in response_chunks:
                if chunk.get("clause_text"):
                    response_text_parts.append(chunk.get("clause_text", ""))
                if chunk.get("summary"):
                    response_text_parts.append(chunk.get("summary", ""))
            response_text = " ".join(response_text_parts) if response_text_parts else "I found relevant information about your query."

            provider_used = hybrid_llm_service.choose_llm_provider(enhanced_message)
            provider_used = provider_used.value if provider_used else "unknown"
        else:
            session_data["messages"].append(HumanMessage(content=enhanced_message))
            result = await graph.ainvoke({"messages": session_data["messages"]}, config)
            response_text = result["messages"][-1].content
            provider_used = "traditional"

        # Send text response
        await send_text_response(websocket, response_text, provider_used, session_data)

        # Send voice response if enabled
        if session_data["user_preferences"]["response_mode"] in ["voice", "both"]:
            # Choose appropriate voice based on user's language
            voice_preference = select_voice_for_language(user_language, session_data["user_preferences"]["preferred_voice"])

            voice_text = voice_service.create_voice_response_with_guidance(
                response_text,
                suggested_resources=["Government portal", "Local offices", "Helpline numbers"],
                redirect_info="contact your local government office for personalized assistance"
            )
            audio_response = await voice_service.text_to_speech(
                voice_text,
                voice_preference
            )
            if audio_response:
                await websocket.send_bytes(audio_response)
            else:
                logger.warning("⚠️ Could not generate audio response")

    except Exception as e:
        logger.error(f"❌ Error processing voice message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error processing voice message: {str(e)}. Please try again or switch to text mode."
        })

async def get_hybrid_response(user_message: str, context: str, config: dict, knowledge_base: str, session_id: str = None):
    """Get response using hybrid LLM with conversational clarity checks and intelligent document search"""
    try:
        # First, determine if this is a government document query or general query
        query_context = analyze_query_context(user_message)
        logger.info(f"🔍 Query analysis: {query_context}")

        # Check for follow-up context from previous clarification requests
        follow_up_context = conversational_service.handle_follow_up(user_message, session_id) if session_id else {'is_follow_up': False}

        if follow_up_context['is_follow_up']:
            # Use enhanced query from follow-up context
            search_query = follow_up_context['enhanced_query']
            logger.info(f"🔄 Using follow-up enhanced query: '{search_query}'")
        else:
            search_query = user_message

        logger.info(f"🔍 Searching documents for: '{search_query}' in knowledge base: {knowledge_base}")
        from rag_service import search_documents_async
        docs = await search_documents_async(search_query, limit=5)  # Increased limit for better results
        logger.info(f"📄 Document search returned {len(docs) if docs else 0} results")

        # Conversational clarity analysis - check if we need clarification
        if not follow_up_context['is_follow_up']:  # Don't ask for clarification on follow-ups
            conversational_analysis = conversational_service.generate_conversational_response(
                user_message, docs, session_id
            )
            if conversational_analysis['needs_clarification']:
                logger.info("❓ Query needs clarification - asking user for more context")
                # Return clarification request instead of search results
                yield {
                    "clause_text": conversational_analysis['response'],
                    "summary": "Clarification request to better understand your question",
                    "role_checklist": ["Please provide more specific information"],
                    "source_title": "Conversational Assistant",
                    "clause_id": "CLARIFICATION_REQUEST",
                    "date": "2024",
                    "url": "",
                    "score": 1.0,
                    "scenario_analysis": None,
                    "charts": [],
                    "needs_clarification": True,
                    "query_type": conversational_analysis.get('query_type', 'unclear')
                }
                return

        # Check if we have relevant documents
        has_relevant_docs = docs and any(doc.get("score", 0) > 0.5 for doc in docs)

        # Always try document search first, even for apparent "general" queries:
        # this is a government document system, so most queries should check
        # documents. Only use the pure LLM for very clear greetings/thanks
        # with NO document matches.
        very_general_keywords = ['hello', 'hi', 'thank you', 'thanks', 'goodbye', 'bye']
        is_very_general = (query_context.get("type") == "general_conversation" and
                           query_context.get("confidence", 0) > 0.8 and
                           any(keyword in user_message.lower() for keyword in very_general_keywords) and
                           not docs)

        if is_very_general:
            logger.info("💬 Detected pure greeting/thanks with no documents, using LLM directly")
            llm_response = await generate_llm_fallback_response(user_message, query_context)
            yield {
                "clause_text": llm_response,
                "summary": "AI-generated response for general conversation",
                "role_checklist": ["This is general information", "For official queries, ask about government policies"],
                "source_title": "AI Assistant",
                "clause_id": "AI_GENERAL",
                "date": "2024",
                "url": "",
                "score": 0.9,
                "scenario_analysis": None,
                "charts": []
            }
            return

        if has_relevant_docs:
            try:
                from scenario_analysis_service import run_scenario_analysis

                # Detect scenario analysis intent (simple keyword match)
                scenario_keywords = ["impact", "cost", "scenario", "multiplier", "da", "dr"]
                if any(kw in user_message.lower() for kw in scenario_keywords):
                    logger.info("📊 Running scenario analysis")
                    # Example params extraction (can be improved)
                    params = {
                        'base_pension': 30000,
                        'multiplier': 1.1 if "multiplier" in user_message.lower() else 1.0,
                        'da_percent': 0.06 if "da" in user_message.lower() else 0.0,
                        'num_beneficiaries': 1000,
                        'years': 3,
                        'inflation': 0.05
                    }
                    scenario_result = run_scenario_analysis(params)

                    # Generate charts for scenario_result
                    try:
                        chart_gen = PolicyChartGenerator()
                        charts = []
                        # Example: line chart for yearly results
                        if "yearly_results" in scenario_result:
                            years = [r['year'] for r in scenario_result['yearly_results']]
                            base_costs = [r['base_cost'] for r in scenario_result['yearly_results']]
                            scenario_costs = [r['scenario_cost'] for r in scenario_result['yearly_results']]

                            # Generate chart and append to charts list
                            fig, ax = plt.subplots(figsize=(10, 6))
                            ax.plot(years, base_costs, label='Base Cost', marker='o')
                            ax.plot(years, scenario_costs, label='Scenario Cost', marker='s')
                            ax.legend()
                            ax.set_title('Scenario Analysis: Cost Over Years')
                            ax.set_xlabel('Year')
                            ax.set_ylabel('Cost (₹)')
                            ax.grid(True, alpha=0.3)

                            # Format y-axis to show values in lakhs
                            ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'₹{x/100000:.1f}L'))

                            buf = io.BytesIO()
                            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
                            buf.seek(0)
                            chart_base64 = base64.b64encode(buf.read()).decode('utf-8')
                            plt.close(fig)
                            charts.append({"type": "line_chart", "data": chart_base64})

                        logger.info(f"✅ Generated {len(charts)} charts for scenario analysis")
                        scenario_result["charts"] = charts
                    except Exception as chart_error:
                        logger.error(f"❌ Failed to generate charts: {chart_error}")
                        scenario_result["charts"] = []
                        scenario_result["chart_error"] = str(chart_error)
                else:
                    scenario_result = None
            except Exception as scenario_error:
                logger.error(f"❌ Scenario analysis failed: {scenario_error}")
                scenario_result = None

            for doc in docs:
                response_obj = {
                    "clause_text": doc.get("clause_text", ""),
                    "summary": doc.get("summary", ""),
                    "role_checklist": doc.get("role_checklist", []),
                    "source_title": doc.get("source_title", ""),
                    "clause_id": doc.get("clause_id", ""),
                    "date": doc.get("date", ""),
                    "url": doc.get("url", ""),
                    "score": doc.get("score", 1.0),
                    "scenario_analysis": scenario_result,
                    "charts": scenario_result.get("charts", []) if scenario_result else []
                }
                yield response_obj
        else:
            # No relevant documents found - use LLM fallback
            logger.info("🔍 No relevant documents found, using LLM fallback")
            llm_response = await generate_llm_fallback_response(user_message, query_context)
            yield {
                "clause_text": llm_response,
                "summary": "Generated by AI assistant for general query",
                "role_checklist": ["Consider if this relates to government policies", "Contact relevant office for official information"],
                "source_title": "AI Assistant",
                "clause_id": "AI_001",
                "date": "2024",
                "url": "",
                "score": 0.8,
                "scenario_analysis": None,
                "charts": []
            }

    except Exception as e:
        logger.warning(f"❌ Document search failed: {e}, using LLM fallback")
        try:
            llm_response = await generate_llm_fallback_response(user_message, {"type": "unknown", "confidence": 0.3})
            yield {
                "clause_text": llm_response,
                "summary": "AI-generated response due to system error",
                "role_checklist": ["Verify information independently", "Try rephrasing your query"],
                "source_title": "AI Assistant (Fallback)",
                "clause_id": "AI_ERROR",
                "date": "2024",
                "url": "",
                "score": 0.5,
                "scenario_analysis": None,
                "charts": []
            }
        except Exception as fallback_error:
            logger.error(f"❌ LLM fallback also failed: {fallback_error}")
            yield {
                "clause_text": "I apologize, but I'm experiencing technical difficulties. Please try again later or rephrase your question.",
                "summary": "System error occurred",
                "role_checklist": ["Try again later", "Rephrase your question", "Contact technical support"],
                "source_title": "System Error",
                "clause_id": "ERROR_001",
                "date": "2024",
                "url": "",
                "score": 0.1,
                "scenario_analysis": None,
                "charts": []
            }

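
# Consumption sketch (illustrative; the config shape matches what
# handle_enhanced_websocket_connection builds for unauthenticated users):
#
#   config = {"configurable": {"thread_id": "demo-thread"}}
#   async for chunk in get_hybrid_response(
#       "pension rules", context="", config=config,
#       knowledge_base="government_docs", session_id=None,
#   ):
#       print(chunk["source_title"], chunk["score"])
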
async def send_text_response(websocket: WebSocket, response_text: str, provider_used: str, session_data: dict):
    """Send text response to client"""
    await websocket.send_json({
        "type": "llm_response",
        "text": response_text,
        "provider_used": provider_used,
        "timestamp": asyncio.get_event_loop().time()
    })
    # Update session context
    session_data["context"] = response_text[-1000:]  # Keep last 1000 chars as context

async def handle_scenario_response(websocket: WebSocket, response_text: str, provider_used: str):
    """Handle scenario analysis response with images"""
    parts = response_text.split("SCENARIO_ANALYSIS_IMAGE:")
    text_part = parts[0].strip()

    # Send text part
    if text_part:
        await websocket.send_json({
            "type": "llm_response",
            "text": text_part,
            "provider_used": provider_used
        })

    # Send image parts
    for i, part in enumerate(parts[1:], 1):
        try:
            image_data = part.strip()
            await websocket.send_json({
                "type": "scenario_image",
                "image_data": image_data,
                "image_index": i,
                "chart_type": "analysis"
            })
        except Exception as e:
            logger.error(f"Error sending scenario image {i}: {e}")

async def handle_preferences_update(websocket: WebSocket, data: dict, session_data: dict):
    """Handle user preferences update"""
    try:
        session_data["user_preferences"].update(data["preferences"])
        await websocket.send_json({
            "type": "preferences_updated",
            "preferences": session_data["user_preferences"]
        })
        logger.info(f"🔧 Updated user preferences: {session_data['user_preferences']}")
    except Exception as e:
        logger.error(f"❌ Error updating preferences: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Error updating preferences: {str(e)}"
        })

# Keep the original function for backward compatibility
async def handle_websocket_connection(websocket: WebSocket):
    """Original websocket handler for backward compatibility"""
    await handle_enhanced_websocket_connection(websocket)
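
# Minimal wiring sketch (module and route names here are assumptions, not
# part of this file):
#
#   from fastapi import FastAPI, WebSocket
#   from websocket_handler import handle_enhanced_websocket_connection
#
#   app = FastAPI()
#
#   @app.websocket("/ws")
#   async def websocket_endpoint(websocket: WebSocket):
#       await handle_enhanced_websocket_connection(websocket)
#
#   # Run with: uvicorn main:app --host 0.0.0.0 --port 8000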