Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,6 +9,9 @@ from transformers import (
|
|
| 9 |
TextIteratorStreamer,
|
| 10 |
)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")
|
| 13 |
|
| 14 |
# -------- Load model & tokenizer --------
|
|
@@ -56,45 +59,45 @@ def format_history_as_messages(history):
|
|
| 56 |
messages.append({"role": "assistant", "content": a})
|
| 57 |
return messages
|
| 58 |
|
|
|
|
| 59 |
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
|
| 60 |
-
"""
|
| 61 |
-
Stream text from model.generate using TextIteratorStreamer.
|
| 62 |
-
"""
|
| 63 |
if seed is not None and seed >= 0:
|
| 64 |
torch.manual_seed(seed)
|
| 65 |
|
| 66 |
inputs = tokenizer.apply_chat_template(
|
| 67 |
messages,
|
| 68 |
-
add_generation_prompt=True,
|
| 69 |
return_tensors="pt",
|
| 70 |
tokenize=True,
|
| 71 |
return_dict=True,
|
| 72 |
)
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 76 |
gen_kwargs = dict(
|
| 77 |
**inputs,
|
| 78 |
-
max_new_tokens=max_new_tokens,
|
| 79 |
temperature=float(temperature),
|
| 80 |
top_p=float(top_p),
|
| 81 |
repetition_penalty=float(repetition_penalty),
|
| 82 |
-
do_sample=
|
| 83 |
use_cache=True,
|
| 84 |
streamer=streamer,
|
| 85 |
)
|
| 86 |
|
| 87 |
-
# Run generation in a thread so we can yield from streamer
|
| 88 |
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
|
| 89 |
thread.start()
|
| 90 |
|
| 91 |
-
|
| 92 |
-
for
|
| 93 |
-
|
| 94 |
-
yield
|
| 95 |
|
| 96 |
# -------- Gradio callbacks --------
|
| 97 |
-
|
| 98 |
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
|
| 99 |
if not user_msg or not user_msg.strip():
|
| 100 |
return gr.update(), chat_history
|
|
|
|
| 9 |
TextIteratorStreamer,
|
| 10 |
)
|
| 11 |
|
| 12 |
+
import spaces
|
| 13 |
+
|
| 14 |
+
|
| 15 |
MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")
|
| 16 |
|
| 17 |
# -------- Load model & tokenizer --------
|
|
|
|
| 59 |
messages.append({"role": "assistant", "content": a})
|
| 60 |
return messages
|
| 61 |
|
| 62 |
+
@spaces.GPU
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
    """Stream a chat completion, yielding the reply text accumulated so far.

    Args:
        messages: Chat messages passed to ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus-sampling threshold, only used when sampling.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        seed: Optional non-negative seed applied via ``torch.manual_seed``.

    Yields:
        The cumulative generated text after each chunk emitted by the streamer.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised once the stream has been drained.
    """
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )

    # Keep only what the model expects
    allowed = {"input_ids", "attention_mask"}  # no token_type_ids for causal LMs
    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Convert once so the sampling decision uses the same value that is passed on.
    temperature = float(temperature)
    do_sample = temperature > 0

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        do_sample=do_sample,
        use_cache=True,
        streamer=streamer,
    )
    if do_sample:
        # Sampling-only knobs: omitting them for greedy decoding does not change
        # the output and avoids passing flags that greedy decoding ignores.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)

    errors = []

    def _run_generation():
        # Run generate in the worker and make sure the consumer is always released.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # thread boundary: stored and re-raised by the consumer
            errors.append(exc)
            # generate() did not finish, so the streamer never got its end signal;
            # send it here, otherwise the loop below would block forever.
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial

    # The stream has ended, so the worker is finishing; wait for it and surface
    # any failure instead of silently returning a truncated reply.
    thread.join()
    if errors:
        raise errors[0]
|
| 98 |
|
| 99 |
# -------- Gradio callbacks --------
|
| 100 |
+
@spaces.GPU
|
| 101 |
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
|
| 102 |
if not user_msg or not user_msg.strip():
|
| 103 |
return gr.update(), chat_history
|