# Llama3-crypto / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
import torch
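# Gradio chat app that serves a LoRA fine-tune of Llama 3.1 8B Instruct
# as a BTC/USDT market-analysis assistant.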
# Hugging Face Hub repo that holds the fine-tuned LoRA adapter and tokenizer
REPO_NAME = "shanaka95/autotrain-01-24"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(REPO_NAME)
# Load the base model the adapter was trained on (must match for the adapter weights to apply)
base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16)
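# Note: at float16, the 8B weights alone take roughly 16 GB (8B params x 2 bytes),
# so a suitably large GPU or plenty of CPU RAM is needed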
# Load the adapter model
adapter_model = PeftModel.from_pretrained(base_model, REPO_NAME)
# Put the model in evaluation mode for inference
adapter_model.eval()
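# Optionally, the LoRA weights could be baked into the base model to drop the
# PEFT wrapper overhead at inference time (a sketch; merge_and_unload() is the
# standard peft call for this):
# adapter_model = adapter_model.merge_and_unload()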
# Create a pipeline for text generation
text_generation = pipeline(
    "text-generation",
    model=adapter_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # use the first GPU if available, else CPU
)
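# When called with a list of chat messages, the pipeline applies the tokenizer's
# chat template and returns the conversation with the model's reply appended as
# the final {'role': 'assistant', ...} message.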
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    # Rebuild the full conversation: system prompt, prior turns, then the new message
    input_text = [
        {'role': 'system', 'content': "You are a highly knowledgeable crypto market analysis expert specializing in BTC/USDT. Using your expertise in both fundamental and technical analysis, evaluate the provided market price data and market news for BTC/USDT from today and the previous four days. After thorough analysis, assist in determining whether tomorrow's BTC/USDT price is likely to increase or decrease compared to today. If the price is predicted to increase, estimate an average maximum price; if the price is predicted to decrease, estimate an average minimum price."},
    ]
    for user_msg, assistant_msg in history:
        if user_msg:
            input_text.append({'role': 'user', 'content': user_msg})
        if assistant_msg:
            input_text.append({'role': 'assistant', 'content': assistant_msg})
    input_text.append({'role': 'user', 'content': message})
    generated_texts = text_generation(
        input_text,
        max_new_tokens=max_tokens,
        num_return_sequences=1,
        temperature=temperature,
        top_p=top_p,
    )
    # The pipeline returns the whole conversation; the model's reply is the last message
    return generated_texts[0]['generated_text'][-1]['content']
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.4, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
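# The sliders in additional_inputs are passed positionally to respond() after
# (message, history): max_tokens, temperature, top_p.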
if __name__ == "__main__":
    demo.launch()