Spaces:
Runtime error
Runtime error
# Minimal single-turn chat-generation script for allenai/OLMo-7B-Instruct:
# load model + tokenizer, render one user message through the chat template,
# and print a single sampled completion.

import logging

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast

# Enable logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(lineno)s - %(funcName)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

MODEL = "allenai/OLMo-7B-Instruct"

# Downloads/loads weights and tokenizer from the Hugging Face hub (or cache).
olmo = OLMoForCausalLM.from_pretrained(MODEL)
tokenizer = OLMoTokenizerFast.from_pretrained(MODEL)

chat = [
    {"role": "user",
     "content": "What is language modeling?"},
]
# Render the conversation into the model's prompt format. The template already
# appends the generation prompt, so no extra special tokens are added below.
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# NOTE: with return_tensors="pt", `encode` returns a single Tensor of input ids
# (not a dict) — relevant for the optional CUDA snippet below.
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

# Optional: run on GPU. `inputs` is a plain tensor, so move it directly
# (the original dict-comprehension form would fail on a tensor). The
# generate() call already moves inputs to the model's device either way.
# olmo = olmo.to('cuda')
# inputs = inputs.to('cuda')

# Sampled decoding; max_new_tokens bounds the completion length.
response = olmo.generate(input_ids=inputs.to(olmo.device), max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])