import gradio as gr
import torch
import spaces
import subprocess
import sys

# Install specific transformers version
subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.48.3"])

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_name = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None


def load_model():
    global model
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
    return model


@spaces.GPU(duration=120)
def generate_response(message, history, enable_reasoning, temperature, top_p, max_tokens):
    """Generate response from the model"""
    # Prepare messages with reasoning control
    messages = []

    # Add system message based on reasoning setting
    if enable_reasoning:
        messages.append({"role": "system", "content": "/think"})
    else:
        messages.append({"role": "system", "content": "/no_think"})

    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Load model if needed
    model = load_model()

    # Tokenize the conversation
    tokenized_chat = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Set generation parameters based on reasoning mode
    if enable_reasoning:
        # Recommended settings for reasoning
        generation_kwargs = {
            "temperature": temperature if temperature > 0 else 0.6,
            "top_p": top_p if top_p < 1 else 0.95,
            "do_sample": True,
            "max_new_tokens": max_tokens,
            "eos_token_id": tokenizer.eos_token_id
        }
    else:
        # Greedy search for non-reasoning
        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": max_tokens,
            "eos_token_id": tokenizer.eos_token_id
        }

    # Generate response
    with torch.no_grad():
        outputs = model.generate(tokenized_chat, **generation_kwargs)

    # Decode and extract the assistant's response
    generated_tokens = outputs[0][tokenized_chat.shape[-1]:]  # Get only new tokens
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return response


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # NVIDIA Nemotron Nano 9B v2 Chatbot

        This chatbot uses the NVIDIA Nemotron Nano 9B v2 model with optional reasoning capabilities.

        - **Enable Reasoning**: Activates the model's chain-of-thought reasoning (/think mode)
        - **Disable Reasoning**: Uses direct response generation (/no_think mode)

        **Note:** Using transformers version 4.48.3 as recommended by the model documentation.
        """
    )

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here...",
        lines=2
    )

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    with gr.Accordion("Advanced Settings", open=False):
        enable_reasoning = gr.Checkbox(
            label="Enable Reasoning (/think mode)",
            value=True,
            info="Enable chain-of-thought reasoning for complex queries"
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=0.6,
            step=0.1,
            label="Temperature",
            info="Controls randomness (recommended: 0.6 for reasoning, ignored for non-reasoning)"
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Controls diversity (recommended: 0.95 for reasoning, ignored for non-reasoning)"
        )
        max_tokens = gr.Slider(
            minimum=32,
            maximum=2048,
            value=1024,
            step=32,
            label="Max New Tokens",
            info="Maximum number of tokens to generate (recommended: 1024+ for reasoning)"
        )

    def user_submit(message, history):
        return "", history + [[message, None]]

    def bot_response(history, enable_reasoning, temperature, top_p, max_tokens):
        if not history:
            return history
        message = history[-1][0]
        try:
            response = generate_response(
                message,
                history[:-1],
                enable_reasoning,
                temperature,
                top_p,
                max_tokens
            )
            history[-1][1] = response
        except Exception as e:
            history[-1][1] = f"Error generating response: {str(e)}"
        return history

    msg.submit(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, enable_reasoning, temperature, top_p, max_tokens],
        chatbot
    )

    submit.click(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, enable_reasoning, temperature, top_p, max_tokens],
        chatbot
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    # Example prompts
    gr.Examples(
        examples=[
            "Write a haiku about GPUs",
            "Explain quantum computing in simple terms",
            "What is the capital of France?",
            "Solve this step by step: If a train travels 120 miles in 2 hours, what is its average speed?",
            "Write a short story about a robot learning to paint"
        ],
        inputs=msg
    )


if __name__ == "__main__":
    demo.launch()