"""Gradio demo for the ``AmirMohseni/reasoning-router-0.6b`` prompt router.

The router is a text-classification model that labels a prompt as ``think``
(needs multi-step reasoning) or ``no_think`` (simple enough for a lightweight
model). This script loads the model once and serves a small web UI around it.
"""

import gradio as gr
from transformers import pipeline

# --- 1. Load the router model ---
# Created once at import time so every UI request reuses the same pipeline.
router = pipeline(
    "text-classification",
    model="AmirMohseni/reasoning-router-0.6b",
    device_map="auto",
)


# --- 2. Define the classification function ---
def classify_prompt(prompt: str) -> dict:
    """Classify *prompt* as 'think' or 'no_think'.

    Args:
        prompt: The user prompt to route.

    Returns:
        A mapping of label name to score, the format expected by Gradio's
        ``Label`` component. An empty dict is returned when *prompt* is not a
        string or is empty / whitespace-only, so the UI shows no result instead
        of running the model on nothing.
    """
    if not isinstance(prompt, str) or not prompt.strip():
        return {}  # Return empty dict for invalid input

    # top_k=None asks the pipeline for the score of every label, not just the
    # top one, so both 'think' and 'no_think' appear in the Label widget.
    results = router(prompt, top_k=None)

    # Format for Gradio Label output: {label: score}
    return {item["label"]: item["score"] for item in results}


# --- 3. Build the Gradio Interface ---
with gr.Blocks(theme="soft", title="Reasoning Router") as demo:
    # Header
    gr.Markdown(
        """
        # 🧠 Reasoning Router 0.6B

        This is a demo for the `AmirMohseni/reasoning-router-0.6b` model.
        It classifies user prompts into two categories:

        - **think** → The task requires complex reasoning (e.g., math, multi-step analysis).
        - **no_think** → The task is simple and can be handled by a lightweight model.

        Enter a prompt below to see how the model classifies it.
        This is useful for building efficient, hybrid model systems.
        """
    )

    # Main interface: prompt box and result side by side, button below.
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt",
            placeholder="e.g., If a train travels at 60 mph, how long does it take to cover 180 miles?",
            lines=3,
        )
        classification_output = gr.Label(label="Classification Result", num_top_classes=2)

    classify_button = gr.Button("Classify", variant="primary")
    classify_button.click(
        fn=classify_prompt,
        inputs=prompt_input,
        outputs=classification_output,
    )

    # Example prompts. cache_examples=True runs classify_prompt on each example
    # up front so clicking one shows its result immediately.
    gr.Examples(
        examples=[
            "What is the sum of the first 100 prime numbers?",
            "What is the capital of France?",
            "Solve for x in the equation 3x - 10 = 2.",
            "List the ingredients for a chocolate cake.",
            # Raw strings are required here: in a plain string "\neq" is parsed
            # as a newline escape followed by "eq", corrupting the LaTeX.
            r"An isosceles trapezoid has an inscribed circle tangent to each of its four sides. "
            r"The radius of the circle is $3$, and the area of the trapezoid is $72$. "
            r"Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \neq s$. "
            r"Find $r^2+s^2$",
        ],
        inputs=prompt_input,
        outputs=classification_output,
        fn=classify_prompt,
        cache_examples=True,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()