AmirMohseni commited on
Commit
a468fc4
·
verified ·
1 Parent(s): 9051814

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import torch
4
+
5
+ try:
6
+ router = pipeline(
7
+ "text-classification",
8
+ model="AmirMohseni/reasoning-router-0.6b",
9
+ device_map="auto",
10
+ torch_dtype=torch.bfloat16
11
+ )
12
+ except Exception as e:
13
+ print(f"Error loading model: {e}")
14
+ # Fallback to CPU if specific GPU setup fails
15
+ router = pipeline(
16
+ "text-classification",
17
+ model="AmirMohseni/reasoning-router-0.6b",
18
+ )
19
+
20
+
21
+ # --- 2. Define the classification function ---
22
+ def classify_prompt(prompt: str) -> dict:
23
+ """
24
+ Classifies the user prompt into 'think' or 'no_think' and returns a dictionary
25
+ formatted for Gradio's Label component.
26
+ """
27
+ if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
28
+ return {} # Return empty dict for invalid input
29
+
30
+ # Run inference
31
+ results = router(prompt, top_k=None) # Get scores for both labels
32
+
33
+ # Format for Gradio Label output
34
+ output_dict = {item['label']: item['score'] for item in results}
35
+ return output_dict
36
+
37
+
38
+ # --- 3. Build the Gradio Interface ---
39
+ with gr.Blocks(theme='soft', title="Reasoning Router") as demo:
40
+ # Header
41
+ gr.Markdown(
42
+ """
43
+ # 🧠 Reasoning Router 0.6B
44
+ This is a demo for the `AmirMohseni/reasoning-router-0.6b` model.
45
+ It classifies user prompts into two categories:
46
+ - **think** → The task requires complex reasoning (e.g., math, multi-step analysis).
47
+ - **no_think** → The task is simple and can be handled by a lightweight model.
48
+
49
+ Enter a prompt below to see how the model classifies it. This is useful for building efficient, hybrid model systems.
50
+ """
51
+ )
52
+
53
+ # Main interface
54
+ with gr.Row():
55
+ prompt_input = gr.Textbox(
56
+ label="Enter Prompt",
57
+ placeholder="e.g., If a train travels at 60 mph, how long does it take to cover 180 miles?",
58
+ lines=3
59
+ )
60
+ classification_output = gr.Label(label="Classification Result", num_top_classes=2)
61
+
62
+ classify_button = gr.Button("Classify", variant="primary")
63
+ classify_button.click(
64
+ fn=classify_prompt,
65
+ inputs=prompt_input,
66
+ outputs=classification_output
67
+ )
68
+
69
+ # Example prompts
70
+ gr.Examples(
71
+ [
72
+ "What is the sum of the first 100 prime numbers?",
73
+ "What is the capital of France?",
74
+ "Solve for x in the equation 3x - 10 = 2.",
75
+ "Can you write me a short poem about the moon?",
76
+ "Explain the theory of relativity in simple terms.",
77
+ "List the ingredients for a chocolate cake."
78
+ ],
79
+ inputs=prompt_input,
80
+ outputs=classification_output,
81
+ fn=classify_prompt,
82
+ cache_examples=True
83
+ )
84
+
85
+ # Launch the app
86
+ if __name__ == "__main__":
87
+ demo.launch()