# AmirMohseni's picture
# Update app.py
# 74d4e40 verified
import gradio as gr
from transformers import pipeline
# --- 1. Load the router model ---
# Hugging Face Hub checkpoint used by this demo.
MODEL_ID = "AmirMohseni/reasoning-router-0.6b"

# Built at module level, so the pipeline is created once when the script is
# loaded and reused by every call to classify_prompt below.
# device_map="auto" leaves device placement to transformers.
router = pipeline(
    task="text-classification",
    model=MODEL_ID,
    device_map="auto",
)
# --- 2. Define the classification function ---
def classify_prompt(prompt: str) -> dict:
    """Return the router's score for every label, shaped for ``gr.Label``.

    Args:
        prompt: Text typed by the user. Per the app description, the labels
            are expected to be ``think`` and ``no_think``.

    Returns:
        ``{label: score}`` for each label the pipeline reports, or ``{}``
        when *prompt* is not a string or contains only whitespace.
    """
    # Guard clause: only non-blank strings are sent to the model.
    has_text = isinstance(prompt, str) and prompt.strip() != ""
    if not has_text:
        return {}

    # top_k=None requests scores for all labels instead of only the top one.
    predictions = router(prompt, top_k=None)
    return {pred['label']: pred['score'] for pred in predictions}
# --- 3. Build the Gradio Interface ---
# NOTE(review): indentation was lost in the source this was recovered from; the
# Row grouping below (textbox beside the label, button underneath) is a
# reconstruction — confirm it matches the intended layout.
with gr.Blocks(theme='soft', title="Reasoning Router") as demo:
    # Header. Blank lines separate the Markdown blocks; without one before the
    # final paragraph it would be rendered as part of the last bullet
    # (Markdown "lazy continuation").
    gr.Markdown(
        """
# 🧠 Reasoning Router 0.6B

This is a demo for the `AmirMohseni/reasoning-router-0.6b` model.
It classifies user prompts into two categories:

- **think** → The task requires complex reasoning (e.g., math, multi-step analysis).
- **no_think** → The task is simple and can be handled by a lightweight model.

Enter a prompt below to see how the model classifies it. This is useful for building efficient, hybrid model systems.
"""
    )

    # Main interface: prompt input next to the per-label score display.
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt",
            placeholder="e.g., If a train travels at 60 mph, how long does it take to cover 180 miles?",
            lines=3,
        )
        classification_output = gr.Label(label="Classification Result", num_top_classes=2)

    classify_button = gr.Button("Classify", variant="primary")
    classify_button.click(
        fn=classify_prompt,
        inputs=prompt_input,
        outputs=classification_output,
    )

    # Example prompts. cache_examples=True asks Gradio to store classify_prompt's
    # output for each example, so each entry must be exactly the intended text.
    gr.Examples(
        [
            "What is the sum of the first 100 prime numbers?",
            "What is the capital of France?",
            "Solve for x in the equation 3x - 10 = 2.",
            "List the ingredients for a chocolate cake.",
            # Raw string: in a normal literal the "\n" of "\neq" is a newline
            # escape, which turned "$r \neq s$" into "$r <newline>eq s$".
            r"An isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is $3$, and the area of the trapezoid is $72$. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \neq s$. Find $r^2+s^2$",
        ],
        inputs=prompt_input,
        outputs=classification_output,
        fn=classify_prompt,
        cache_examples=True,
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()