from fastapi import FastAPI, Request, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get Hugging Face Space configuration
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""

# Initialize FastAPI with the correct base path. root_path tells FastAPI
# which prefix the proxy strips; docs_url is resolved relative to it, so
# prefixing it with BASE_PATH again would double the prefix.
app = FastAPI(
    title="DialoGPT API",
    description="Chatbot API using Microsoft's DialoGPT-medium model",
    version="1.0",
    root_path=BASE_PATH,
    docs_url="/docs",
    redoc_url=None
)

# Load model and tokenizer
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e
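
# Optional (a sketch, not part of the original setup): move the model to a
# GPU when one is available. This is what would make the CUDA out-of-memory
# handler in the /ai endpoint relevant; as written, the model runs on CPU.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)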

# In-memory chat history storage: maps user_id -> [conversation token tensor].
# Per-process only, so history is lost on restart and grows until reset.
chat_history = {}

@app.get("/", include_in_schema=False)
async def root():
    return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}

@app.get("/ai")
async def chat(request: Request):
    try:
        # Get query parameters
        user_input = request.query_params.get("query", "").strip()
        user_id = request.query_params.get("user_id", "default").strip()
        
        # Validate input
        if not user_input:
            raise HTTPException(
                status_code=400,
                detail="Missing 'query' parameter. Usage: /ai?query=Hello&user_id=yourname"
            )
        if len(user_input) > 200:
            raise HTTPException(
                status_code=400,
                detail="Query too long (max 200 characters)"
            )

        # Process the query
        new_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token,
            return_tensors='pt'
        )
        
        # Retrieve user history
        user_history = chat_history.get(user_id, [])
        
        # Generate bot response
        bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids
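        # Sampling with top_k/top_p keeps replies varied; max_new_tokens
        # caps only the length of the reply, not the growing prompt.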
        output_ids = model.generate(
            bot_input_ids,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95
        )
        
        # Decode and clean response
        response = tokenizer.decode(
            output_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()
        
        # Update chat history. output_ids already contains the full
        # conversation (history + new input + reply), so store only that
        # tensor; keeping bot_input_ids as well would duplicate the whole
        # context on the next turn.
        chat_history[user_id] = [output_ids]
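        # Optional hardening (a sketch, not in the original code): truncate
        # the stored context so long conversations cannot grow unboundedly,
        # e.g. keep only the most recent 512 tokens:
        # chat_history[user_id] = [output_ids[:, -512:]]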
        
        return {"reply": response}
    
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Clear history to free memory
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(
            status_code=500,
            detail="Memory error. Conversation history cleared. Please try again."
        )
    
    except HTTPException:
        # Re-raise client errors (e.g. the 400s above) unchanged rather
        # than letting the generic handler rewrap them as 500s.
        raise

    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing error: {str(e)}"
        ) from e

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model": "microsoft/DialoGPT-medium",
        "users": len(chat_history),
        "space_id": HF_SPACE
    }

@app.get("/reset")
async def reset_history(user_id: str = "default"):
    if user_id in chat_history:
        del chat_history[user_id]
    return {"status": "success", "message": f"History cleared for user {user_id}"}

# Only run with uvicorn when executing locally
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30
    )
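
# Example usage (a sketch, assuming the server is running locally on the
# port configured above; when deployed on a Space, replace localhost:7860
# with the Space's URL):
#
#   curl "http://localhost:7860/ai?query=Hello&user_id=alice"
#   # -> {"reply": "..."}
#
#   curl "http://localhost:7860/health"
#   # -> {"status": "healthy", "model": "microsoft/DialoGPT-medium", ...}
#
#   curl "http://localhost:7860/reset?user_id=alice"
#   # -> {"status": "success", "message": "History cleared for user alice"}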