from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Spaces path setup
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""

# FastAPI initialization
app = FastAPI(
    title="Phi-2 Chat API",
    description="Chatbot API using microsoft/phi-2, CPU-optimized",
    version="1.0",
    root_path=BASE_PATH,
    docs_url="/docs" if not BASE_PATH else f"{BASE_PATH}/docs",
    redoc_url=None
)

# Load model and tokenizer (phi-2 defaults to float32 weights, which is what
# we want for CPU-only inference)
try:
    logger.info("Loading Phi-2 tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    model.eval()  # inference only: disables dropout
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e

# In-memory chat history
chat_history = {}
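# Note: this history lives in process memory only; it is lost on restart and
# not shared across workers. A persistent store (e.g. Redis) would be needed
# for anything beyond a demo.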

# System prompt to guide tone
SYSTEM_PROMPT = (
    "You are a helpful, chill, clever, and fun AI assistant called 𝕴 𝖆𝖒 π–π–Žπ–’. "
    "Talk like a smooth, witty friend. Be friendly and humanlike.\n"
)

@app.get("/", include_in_schema=False)
async def root():
    return {"message": "🟒 Phi-2 API is live. Use /ai?query=Hello&user_id=yourname"}

@app.get("/ai")
async def chat(request: Request):
    try:
        user_input = request.query_params.get("query", "").strip()
        user_id = request.query_params.get("user_id", "default").strip()

        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'query' parameter.")
        if len(user_input) > 200:
            raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")

        # Retrieve last conversation
        user_history = chat_history.get(user_id, [])
        history_prompt = ""

        for entry in user_history[-3:]:  # Last 3 exchanges
            history_prompt += f"User: {entry['q']}\nAI: {entry['a']}\n"
        full_prompt = SYSTEM_PROMPT + history_prompt + f"User: {user_input}\nAI:"
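        # The assembled prompt looks like:
        #   <SYSTEM_PROMPT>
        #   User: <q1>
        #   AI: <a1>
        #   ... (up to the last 3 exchanges)
        #   User: <current input>
        #   AI: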

        # Tokenize and generate; torch.no_grad() skips autograd bookkeeping,
        # and passing the attention mask avoids a generate() warning
        inputs = tokenizer(full_prompt, return_tensors="pt")
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens
        response = tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
        ).strip()

        # Phi-2 is a base model and often writes the next "User:" turn itself;
        # keep only the assistant's reply
        response = response.split("User:")[0].strip()

        # Store updated history, capped at the last 10 exchanges per user
        user_history.append({"q": user_input, "a": response})
        chat_history[user_id] = user_history[-10:]

        return {"reply": response}

    except HTTPException:
        # Re-raise deliberate 4xx errors instead of masking them as 500s
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model": "microsoft/phi-2",
        "users": len(chat_history),
        "space_id": HF_SPACE
    }

@app.get("/reset")
async def reset_history(user_id: str = "default"):
    if user_id in chat_history:
        del chat_history[user_id]
    return {"status": "success", "message": f"History cleared for user {user_id}"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30
    )
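
# Quick smoke test against a locally running instance (assumes port 7860 and a
# user_id of "alice"; adjust as needed):
#
#   curl "http://localhost:7860/ai?query=Hello&user_id=alice"
#   curl "http://localhost:7860/health"
#   curl "http://localhost:7860/reset?user_id=alice"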