"""RPG Room Generator App.

Sets the sampling parameters and provides a minimal interface to the user.
Sampling background: https://huggingface.co/blog/how-to-generate
"""
import gradio as gr
from gradio import inputs  # allows easier doc lookup in Pycharm
import transformers as tr

MPATH = "./models/mdl_roomgen7"
MODEL = tr.GPT2LMHeadModel.from_pretrained(MPATH)

# ToDo: Will save tokenizer next time so can replace this with a load
# NOTE(review): the empty pad_token looks suspicious; confirm it matches the
# tokenizer the model was trained with before changing it (token ids must
# line up with the saved embeddings).
SPECIAL_TOKENS = {
    'eos_token': '<|EOS|>',
    'bos_token': '<|endoftext|>',
    'pad_token': '',
    'sep_token': '<|body|>'
}
TOK = tr.GPT2Tokenizer.from_pretrained("gpt2")
TOK.add_special_tokens(SPECIAL_TOKENS)

# Sampling presets offered in the "Craziness" radio button.
# "temperature" and "top_p" are stored as percentages and divided by 100 in
# generate_room; "top_k" is used as-is.
SAMPLING_OPTIONS = {
    "Reasonable": {"top_k": 25, "temperature": 50, "top_p": 60},
    "Odd": {"top_k": 50, "temperature": 75, "top_p": 90},
    "Insane": {"top_k": 300, "temperature": 100, "top_p": 85},
}


def generate_room(room_name, room_desc, max_length, sampling_method):
    """Use the pretrained model to generate text for a dungeon room.

    The prompt has the form ``<bos> room_name <sep> room_desc``; the model
    continues it, and everything up to and including the separator is
    dropped from the returned text.

    Args:
        room_name: Name of the room, placed between the bos and sep tokens.
        room_desc: Optional start of the room description; the model
            continues from this text. May be an empty string.
        max_length: Maximum total length in tokens (prompt included) passed
            to ``MODEL.generate``. Coerced to ``int`` because the UI slider
            may deliver a float.
        sampling_method: Key of ``SAMPLING_OPTIONS`` selecting the preset.

    Returns:
        Room description text: ``room_desc`` plus the generated
        continuation, cut at the ``<|EOS|>`` marker and trimmed back to the
        last period (when one exists past index 0).
    """
    prefix = " ".join(
        [SPECIAL_TOKENS["bos_token"], room_name, SPECIAL_TOKENS["sep_token"]]
    )
    prompt = " ".join([prefix, room_desc])

    # Only want to skip the room name part: the user's description is kept
    # in the output, everything up to the separator token is dropped.
    n_skip = len(TOK.encode(prefix))
    ids = TOK.encode(prompt, return_tensors="pt")

    # Sample
    options = SAMPLING_OPTIONS[sampling_method]
    output = MODEL.generate(
        ids,
        max_length=int(max_length),
        do_sample=True,
        top_k=options["top_k"],
        temperature=options["temperature"] / 100.,
        top_p=options["top_p"] / 100.,
        # Stop on this app's own end-of-room marker instead of relying on
        # whatever eos_token_id the saved model config happens to carry.
        eos_token_id=TOK.eos_token_id,
    )

    text = TOK.decode(output[0][n_skip:], clean_up_tokenization_spaces=True)
    # Collapse double spaces in the decoded text.
    text = text.replace("  ", " ")
    # Drop the end-of-room marker and anything generated after it.
    text = text.split(SPECIAL_TOKENS["eos_token"], 1)[0]

    # Slice off last partial sentence
    last_period = text.rfind(".")
    if last_period > 0:
        text = text[:last_period + 1]
    return text


if __name__ == "__main__":
    iface = gr.Interface(
        title="RPG Room Generator",
        fn=generate_room,
        inputs=[
            inputs.Textbox(lines=1, label="Room Name"),
            inputs.Textbox(lines=3, label="Start of Room Description (Optional)", default=""),
            inputs.Slider(minimum=50, maximum=1000, default=200, label="Length"),
            inputs.Radio(choices=list(SAMPLING_OPTIONS.keys()), default="Odd", label="Craziness"),
        ],
        outputs="text",
        layout="horizontal",
        allow_flagging="never",
        theme="dark",
    )
    iface.launch()