from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Spaces sets the SPACE_ID environment variable; use it to build a
# path prefix when the API runs inside a Space.
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""

app = FastAPI(
    title="DialoGPT API",
    description="Chatbot API using Microsoft's DialoGPT-medium model",
    version="1.0",
    root_path=BASE_PATH,
    # root_path already tells FastAPI about the proxy prefix, so the docs route
    # itself stays at /docs; prefixing docs_url as well would double the prefix.
    docs_url="/docs",
    redoc_url=None
)

# Load the tokenizer and model once at startup; fail fast if either load fails.
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e
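
# The model is used on CPU as written. If a GPU is available, one option (an
# assumption, not part of the original code) is to move the model and, later,
# the input ids onto the same device:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)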

# Per-user conversation state: user_id -> list containing the token ids of the
# conversation so far. Kept in process memory, so it is lost on restart and is
# not shared between workers.
chat_history = {}

@app.get("/", include_in_schema=False) |
|
|
async def root(): |
|
|
return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"} |
|
|
|
|
|
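
# Example request against the chat endpoint (a sketch; the host and port match
# the uvicorn settings at the bottom of this file, and "alice" is a placeholder
# user id):
#   curl "http://localhost:7860/ai?query=Hello&user_id=alice"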
@app.get("/ai") |
|
|
async def chat(request: Request): |
|
|
try: |
|
|
|
|
|
user_input = request.query_params.get("query", "").strip() |
|
|
user_id = request.query_params.get("user_id", "default").strip() |
|
|
|
|
|
|
|
|
if not user_input: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail="Missing 'query' parameter. Usage: /ai?query=Hello&user_id=yourname" |
|
|
) |
|
|
if len(user_input) > 200: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail="Query too long (max 200 characters)" |
|
|
) |
|
|
|
|
|
|
|
|
new_input_ids = tokenizer.encode( |
|
|
user_input + tokenizer.eos_token, |
|
|
return_tensors='pt' |
|
|
) |
|
|
|
|
|
|
|
|
user_history = chat_history.get(user_id, []) |
|
|
|
|
|
|
|
|
bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids |
|
|
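
        # Note: DialoGPT-medium has a 1024-token context window and this history
        # grows without bound. A simple guard (an assumption, not part of the
        # original logic) would be to start a fresh conversation once the prompt
        # gets long, e.g.:
        #   if bot_input_ids.shape[-1] > 900:
        #       chat_history.pop(user_id, None)
        #       bot_input_ids = new_input_ids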

        # Sample a reply with top-k / nucleus sampling rather than greedy decoding.
        output_ids = model.generate(
            bot_input_ids,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95
        )
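
        # model.generate() is compute-bound and blocks the event loop inside this
        # async route. If that becomes an issue, one option (an assumption, not
        # part of the original code) is to offload it to a worker thread:
        #   from fastapi.concurrency import run_in_threadpool
        #   output_ids = await run_in_threadpool(
        #       model.generate, bot_input_ids, max_new_tokens=100,
        #       pad_token_id=tokenizer.eos_token_id, do_sample=True,
        #       top_k=50, top_p=0.95
        #   )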

        # Decode only the newly generated tokens (everything after the prompt).
        response = tokenizer.decode(
            output_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()

        # output_ids already contains the prompt plus the new reply, so store only
        # that tensor; also keeping bot_input_ids would duplicate the history on
        # the next turn.
        chat_history[user_id] = [output_ids]

        return {"reply": response}

    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Free this user's history before asking them to retry.
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(
            status_code=500,
            detail="Memory error. Conversation history cleared. Please try again."
        )
    except HTTPException:
        # Let the validation errors raised above pass through unchanged instead of
        # being re-wrapped as 500s by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing error: {str(e)}"
        ) from e

@app.get("/health") |
|
|
async def health_check(): |
|
|
return { |
|
|
"status": "healthy", |
|
|
"model": "microsoft/DialoGPT-medium", |
|
|
"users": len(chat_history), |
|
|
"space_id": HF_SPACE |
|
|
} |
|
|
|
|
|
@app.get("/reset") |
|
|
async def reset_history(user_id: str = "default"): |
|
|
if user_id in chat_history: |
|
|
del chat_history[user_id] |
|
|
return {"status": "success", "message": f"History cleared for user {user_id}"} |
|
|
|
|
|
|
|
|
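
# To clear a single user's history (a sketch; "alice" is a placeholder user id):
#   curl "http://localhost:7860/reset?user_id=alice"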

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30
    )