Spaces:
Sleeping
Sleeping
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from langchain.prompts import PromptTemplate | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
| from embeddings.embeddings import generate_embeddings | |
| from elastic.retrieval import search_certification_chunks | |
| from prompting.rewrite_question import classify_certification, initialize_llms, process_query | |
# Load variables from a local .env file into os.environ (python-dotenv does not
# override variables that are already set) before the clients below are made.
load_dotenv()

# OpenAPI metadata shown by FastAPI's generated docs.
_APP_METADATA = dict(
    title="Hydrogen Certification RAG System",
    description="API for querying hydrogen certification documents using RAG",
    version="0.1.0",
)
app = FastAPI(**_APP_METADATA)

# LLM clients, created once at import time and shared by every request.
# Nothing in this block creates an Elasticsearch client; retrieval goes
# through search_certification_chunks.
llms = initialize_llms()
# Request models
class QueryRequest(BaseModel):
    # Free-text question from the client. There is no separate certification
    # field: the handler infers the certification from this text.
    query: str
# LLM used to generate answers. Reuse the "rewrite_llm" client already built
# into `llms` above instead of calling initialize_llms() a second time, which
# re-ran the whole initialisation just to pick out one client.
llm = llms["rewrite_llm"]
# Endpoints

# Answer-generation prompt. Built once at import time rather than on every
# request; the template text is unchanged.
_ANSWER_TEMPLATE = """
You are an AI assistant tasked with providing answers based on the given context about a specific hydrogen certification.
Provide a clear, concise response that directly addresses the question without unnecessary information.
Question: {question}
Certification: {certification}
Context: {context}
Answer:
"""
_ANSWER_PROMPT = PromptTemplate(
    input_variables=["question", "certification", "context"],
    template=_ANSWER_TEMPLATE,
)


def _retrieve_texts(index_name, certification, text_query, vector_query):
    """Search ``index_name`` for chunks of ``certification``; return their texts.

    The hits are materialised into a list exactly once. The texts are used
    twice (returned to the client and joined into the prompt context), and
    iterating the raw search result twice would silently yield an empty
    context if it were a one-shot iterator.
    """
    hits = search_certification_chunks(
        index_name=index_name,
        certification_name=certification,
        text_query=text_query,
        vector_query=vector_query,
    )
    return [hit["text"] for hit in hits]


def _answer_from_context(question, certification, context_texts):
    """Generate an answer to ``question`` grounded in ``context_texts``.

    The chunk texts are joined with ". " and sent through ``_ANSWER_PROMPT``
    to the module-level ``llm``. The ``.content`` of the reply is returned.
    """
    chain = _ANSWER_PROMPT | llm
    reply = chain.invoke(
        {
            "question": question,
            "certification": certification,
            "context": ". ".join(context_texts),
        }
    )
    return reply.content


# NOTE(review): no FastAPI route decorator (e.g. ``@app.post(...)``) is visible
# on this handler, so as shown it is never registered on ``app``. Confirm
# against the deployed file.
async def handle_query(request: QueryRequest):
    """
    Answer a question with the RAG pipeline, once per search index.

    1. Classify which certification the question refers to. It is always
       inferred from ``request.query``; the request has no separate field.
    2. Pass the query through ``process_query`` and embed the result.
    3. Search ``certif_index`` and ``certification_index`` with both the
       processed text and its embedding.
    4. Generate one answer per index from that index's retrieved chunks.

    Returns:
        dict with the detected ``certification``, one answer per index
        (``certif_index`` / ``certification_index``), and the chunk texts each
        answer was based on (``context_certif`` / ``context_certifications``).

    Raises:
        HTTPException: 400 if the classifier reports that no certification is
            mentioned; 500 for any other failure in the pipeline.
    """
    # NOTE(review): chain.invoke is LangChain's synchronous API, so this
    # ``async def`` handler blocks the event loop while the LLM runs. Consider
    # a plain ``def`` endpoint (FastAPI runs those in a threadpool); confirm
    # whether the search/embedding helpers are blocking too.
    try:
        # Step 1: determine the certification from the question text.
        certification = classify_certification(request.query, llms["rewrite_llm"])
        if "no certification mentioned" in certification:
            raise HTTPException(
                status_code=400,
                detail="No certification specified in query and none provided",
            )

        # Step 2: process the query and embed it once for both searches.
        processed_query = process_query(request.query, llms)
        question_vector = generate_embeddings(processed_query)

        # Step 3: retrieve from both indices before generating any answer.
        context_certif = _retrieve_texts(
            "certif_index", certification, processed_query, question_vector
        )
        context_certifications = _retrieve_texts(
            "certification_index", certification, processed_query, question_vector
        )

        # Step 4: one answer per index, each from its own context.
        answer_certif = _answer_from_context(
            processed_query, certification, context_certif
        )
        answer_certifications = _answer_from_context(
            processed_query, certification, context_certifications
        )
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must pass through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "certification": certification,
        "certif_index": answer_certif,
        "certification_index": answer_certifications,
        "context_certif": context_certif,
        "context_certifications": context_certifications,
    }
async def list_certifications():
    """List all available certifications.

    Each sub-directory of ``docs/processed`` counts as one certification;
    plain files are ignored. The path is relative, so it is resolved against
    the process's current working directory.

    Returns:
        The directory names, sorted so the response order is stable
        (``os.listdir`` returns entries in arbitrary order).

    Raises:
        HTTPException: 500 if the directory cannot be read (e.g. it is missing).
    """
    # NOTE(review): no route decorator is visible here either; confirm this
    # handler is registered on ``app`` in the deployed file.
    certs_dir = "docs/processed"
    try:
        return sorted(
            name
            for name in os.listdir(certs_dir)
            if os.path.isdir(os.path.join(certs_dir, name))
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces.
    from uvicorn import run

    run(app, port=8000, host="0.0.0.0")