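"""FastAPI chat service wrapping Microsoft's DialoGPT-medium.

Example requests (a sketch, assuming the app runs locally on the port
configured at the bottom of this file):

    curl "http://localhost:7860/ai?query=Hello&user_id=alice"
    curl "http://localhost:7860/health"
    curl "http://localhost:7860/reset?user_id=alice"
"""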
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Get Hugging Face Space configuration
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""
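# e.g. SPACE_ID is typically "<owner>/<space-name>", giving a base path
# like "/spaces/<owner>/<space-name>" (an assumption about the hosting setup)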
# Initialize FastAPI; root_path lets generated URLs honor the proxy prefix
app = FastAPI(
    title="DialoGPT API",
    description="Chatbot API using Microsoft's DialoGPT-medium model",
    version="1.0",
    root_path=BASE_PATH,
    # root_path already carries the prefix, so docs_url stays relative;
    # prefixing it again would double the path when BASE_PATH is set
    docs_url="/docs",
    redoc_url=None
)
# Load model and tokenizer once at startup
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    model.eval()  # inference only: disable dropout
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e
# In-memory chat history storage
chat_history = {}
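# Note: history lives in this process's memory only; it is lost on restart
# and is not shared across multiple workers.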
@app.get("/", include_in_schema=False)
async def root():
return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}
@app.get("/ai")
async def chat(request: Request):
try:
# Get query parameters
user_input = request.query_params.get("query", "").strip()
user_id = request.query_params.get("user_id", "default").strip()
# Validate input
if not user_input:
raise HTTPException(
status_code=400,
detail="Missing 'query' parameter. Usage: /ai?query=Hello&user_id=yourname"
)
if len(user_input) > 200:
raise HTTPException(
status_code=400,
detail="Query too long (max 200 characters)"
)
    try:
        # Encode the new user message, appending the EOS token DialoGPT
        # uses to separate conversation turns
        new_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token,
            return_tensors='pt'
        )
        # Retrieve user history and prepend it to the new input
        user_history = chat_history.get(user_id, [])
        bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids
        # Generate bot response; no_grad avoids tracking gradients at inference
        with torch.no_grad():
            output_ids = model.generate(
                bot_input_ids,
                max_new_tokens=100,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )
        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            output_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()
        # Update chat history: output_ids already contains the full
        # conversation (prompt + reply), so storing it alone avoids
        # duplicating the context on the next turn
        chat_history[user_id] = [output_ids]
        return {"reply": response}
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Clear this user's history to free memory (no-op if absent)
        chat_history.pop(user_id, None)
        raise HTTPException(
            status_code=500,
            detail="Memory error. Conversation history cleared. Please try again."
        )
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing error: {str(e)}"
        ) from e
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model": "microsoft/DialoGPT-medium",
"users": len(chat_history),
"space_id": HF_SPACE
}
@app.get("/reset")
async def reset_history(user_id: str = "default"):
if user_id in chat_history:
del chat_history[user_id]
return {"status": "success", "message": f"History cleared for user {user_id}"}

# Only run with uvicorn when executing locally
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30
    )
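# Alternative launch (a sketch, assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860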