from fastapi import FastAPI, Request, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
# On Hugging Face Spaces, the SPACE_ID env var identifies the deployment;
# it is used here to build a root path so generated URLs work behind the proxy.
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""

app = FastAPI(
    title="Phi-2 Chat API",
    description="Chatbot API using microsoft/phi-2, CPU-optimized",
    version="1.0",
    root_path=BASE_PATH,
    # root_path is already applied by FastAPI, so docs_url must stay relative;
    # prefixing it with BASE_PATH again would serve the docs at a doubled path.
    docs_url="/docs",
    redoc_url=None,
)
|
# Load the tokenizer and weights once at startup so requests don't pay the cost.
try:
    logger.info("Loading Phi-2 tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    model.eval()  # inference only: disables dropout
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise RuntimeError("Model initialization failed") from e
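# Memory note: phi-2 has roughly 2.7B parameters, so float32 weights alone take
# about 11 GB of RAM. A lower-memory loading sketch (assuming the host CPU
# handles bfloat16 acceptably) would be:
#   model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/phi-2", torch_dtype=torch.bfloat16
#   )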
# Per-user conversation history, kept in process memory (lost on restart).
chat_history = {}

SYSTEM_PROMPT = (
    "You are a helpful, chill, clever, and fun AI assistant called π΄ ππ πππ. "
    "Talk like a smooth, witty friend. Be friendly and humanlike.\n"
)
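# The prompt assembled in /ai below has this shape (system prompt, up to three
# prior turns for the user, then the new question):
#   <system prompt>
#   User: <earlier question>
#   AI: <earlier answer>
#   User: <current question>
#   AI: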
@app.get("/", include_in_schema=False) |
|
|
async def root(): |
|
|
return {"message": "π’ Phi-2 API is live. Use /ai?query=Hello&user_id=yourname"} |
|
|
|
|
|
@app.get("/ai") |
|
|
async def chat(request: Request): |
|
|
try: |
|
|
user_input = request.query_params.get("query", "").strip() |
|
|
user_id = request.query_params.get("user_id", "default").strip() |
|
|
|
|
|
if not user_input: |
|
|
raise HTTPException(status_code=400, detail="Missing 'query' parameter.") |
|
|
if len(user_input) > 200: |
|
|
raise HTTPException(status_code=400, detail="Query too long (max 200 characters)") |
|
|
|
|
|
|
|
|
user_history = chat_history.get(user_id, []) |
|
|
history_prompt = "" |
|
|
|
|
|
for entry in user_history[-3:]: |
|
|
history_prompt += f"User: {entry['q']}\nAI: {entry['a']}\n" |
|
|
full_prompt = SYSTEM_PROMPT + history_prompt + f"User: {user_input}\nAI:" |
|
|
|
|
|
|
|
|
input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids |
|
|
output_ids = model.generate( |
|
|
input_ids, |
|
|
max_new_tokens=100, |
|
|
temperature=0.8, |
|
|
top_p=0.95, |
|
|
do_sample=True, |
|
|
pad_token_id=tokenizer.eos_token_id |
|
|
) |
|
|
|
|
|
response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True).strip() |
|
|
|
|
|
|
|
|
user_history.append({"q": user_input, "a": response}) |
|
|
chat_history[user_id] = user_history |
|
|
|
|
|
return {"reply": response} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error: {e}") |
|
|
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e |
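# Example request (illustrative; replies vary between runs because sampling
# is enabled):
#   curl "http://localhost:7860/ai?query=Hello&user_id=alice"
#   -> {"reply": "Hey there! What can I do for you today?"}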
@app.get("/health") |
|
|
async def health_check(): |
|
|
return { |
|
|
"status": "healthy", |
|
|
"model": "microsoft/phi-2", |
|
|
"users": len(chat_history), |
|
|
"space_id": HF_SPACE |
|
|
} |
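# Example: curl "http://localhost:7860/health"
#   -> {"status": "healthy", "model": "microsoft/phi-2", "users": 1, "space_id": ""}
# "users" counts the distinct user_ids with stored history; space_id is empty
# when running outside Hugging Face Spaces.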
@app.get("/reset") |
|
|
async def reset_history(user_id: str = "default"): |
|
|
if user_id in chat_history: |
|
|
del chat_history[user_id] |
|
|
return {"status": "success", "message": f"History cleared for user {user_id}"} |
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30,
    )
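# To run locally (assuming this file is saved as app.py): python app.py
# Port 7860 is the port Hugging Face Spaces expects a web app to listen on.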