from fastapi import FastAPI, Request, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get Hugging Face Space configuration
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""

# Initialize FastAPI with the correct base path.
# root_path is the prefix the Space's proxy already strips, so app-relative
# paths such as /docs must stay un-prefixed here or they end up doubled.
app = FastAPI(
    title="DialoGPT API",
    description="Chatbot API using Microsoft's DialoGPT-medium model",
    version="1.0",
    root_path=BASE_PATH,
    docs_url="/docs",
    redoc_url=None
)

# Load model and tokenizer
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    model.eval()  # inference only; disables dropout
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e

# In-memory chat history storage: user_id -> [conversation token ids]
chat_history = {}


@app.get("/", include_in_schema=False)
async def root():
    return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}


@app.get("/ai")
async def chat(request: Request):
    # Resolve user_id before the try block so the error handlers below can
    # always reference it safely
    user_id = request.query_params.get("user_id", "default").strip()
    try:
        # Get and validate the query parameter
        user_input = request.query_params.get("query", "").strip()
        if not user_input:
            raise HTTPException(
                status_code=400,
                detail="Missing 'query' parameter. Usage: /ai?query=Hello&user_id=yourname"
            )
        if len(user_input) > 200:
            raise HTTPException(
                status_code=400,
                detail="Query too long (max 200 characters)"
            )

        # Encode the new user turn, terminated by the EOS token DialoGPT expects
        new_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token,
            return_tensors="pt"
        )

        # Append the new turn to any stored conversation history
        user_history = chat_history.get(user_id, [])
        bot_input_ids = (
            torch.cat(user_history + [new_input_ids], dim=-1)
            if user_history else new_input_ids
        )

        # Generate the bot response (no_grad avoids building autograd state)
        with torch.no_grad():
            output_ids = model.generate(
                bot_input_ids,
                max_new_tokens=100,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )

        # Decode only the newly generated tokens
        response = tokenizer.decode(
            output_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()

        # generate() returns the prompt plus the reply, so output_ids alone
        # already holds the full conversation; storing bot_input_ids as well
        # would duplicate every earlier turn on the next request.
        chat_history[user_id] = [output_ids]

        return {"reply": response}

    except HTTPException:
        # Re-raise validation errors so they keep their 400 status instead of
        # being swallowed by the generic handler below
        raise
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Clear this user's history to free memory
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(
            status_code=500,
            detail="Memory error. Conversation history cleared. Please try again."
        )
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing error: {str(e)}"
        ) from e


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model": "microsoft/DialoGPT-medium",
        "users": len(chat_history),
        "space_id": HF_SPACE
    }


@app.get("/reset")
async def reset_history(user_id: str = "default"):
    if user_id in chat_history:
        del chat_history[user_id]
    return {"status": "success", "message": f"History cleared for user {user_id}"}


# Only run uvicorn directly when executing locally
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30
    )
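
# --- Usage sketch ---
# A minimal smoke test of the three endpoints above, assuming the app is
# running locally on port 7860 (the uvicorn configuration in this file); on a
# Hugging Face Space, substitute the Space's public URL for localhost.
# "alice" is a hypothetical user_id chosen purely for illustration.
#
#   curl "http://localhost:7860/ai?query=Hello&user_id=alice"
#   # -> {"reply": "<model's response>"}
#
#   curl "http://localhost:7860/health"               # service/model status
#   curl "http://localhost:7860/reset?user_id=alice"  # clear alice's history
#
# Because history is keyed by user_id and held in process memory, repeated
# calls with the same user_id continue one conversation; it resets whenever
# the process restarts or /reset is called.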