Spaces:
Runtime error
Runtime error
Provide default description for image generation if missing, add repetition penalty of 1.2 to text generation
Browse files
app.py
CHANGED
|
@@ -86,8 +86,10 @@ learner = load_learner('export.pkl',
|
|
| 86 |
cpu=not gpu) # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
|
| 87 |
|
| 88 |
|
| 89 |
-
def parse_monster_description(text):
|
| 90 |
match = re.search(r"Description: (.*)", text)
|
|
|
|
|
|
|
| 91 |
description = match.group(1)
|
| 92 |
print(description.split('.')[0])
|
| 93 |
return description.split('.')[0]
|
|
@@ -101,7 +103,8 @@ def gen_monster_text(name):
|
|
| 101 |
inp = tensor(prompt_ids)[None].cuda() # Use .cuda() for torch GPU
|
| 102 |
else:
|
| 103 |
inp = tensor(prompt_ids)[None]
|
| 104 |
-
preds = learner.model.generate(inp, max_length=1024, num_beams=5, temperature=1.5, do_sample=True
|
|
|
|
| 105 |
result = tokenizer.decode(preds[0].cpu().numpy())
|
| 106 |
result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
|
| 107 |
print(f'GENERATING MONSTER COMPLETE')
|
|
@@ -299,7 +302,7 @@ def run(name: str) -> (Image, str, Image, str):
|
|
| 299 |
placeholder_image = Image.new(mode="RGB", size=(256, 256))
|
| 300 |
return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
|
| 301 |
text = gen_monster_text(name)
|
| 302 |
-
description = parse_monster_description(text)
|
| 303 |
pil = gen_image(description)
|
| 304 |
image_data = pil_to_base64(pil)
|
| 305 |
card_html = format_monster_card(text, image_data)
|
|
|
|
| 86 |
cpu=not gpu) # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
|
| 87 |
|
| 88 |
|
| 89 |
+
def parse_monster_description(name, text):
|
| 90 |
match = re.search(r"Description: (.*)", text)
|
| 91 |
+
if not match:
|
| 92 |
+
return f"{name} is a monster."
|
| 93 |
description = match.group(1)
|
| 94 |
print(description.split('.')[0])
|
| 95 |
return description.split('.')[0]
|
|
|
|
| 103 |
inp = tensor(prompt_ids)[None].cuda() # Use .cuda() for torch GPU
|
| 104 |
else:
|
| 105 |
inp = tensor(prompt_ids)[None]
|
| 106 |
+
preds = learner.model.generate(inp, max_length=1024, num_beams=5, temperature=1.5, do_sample=True,
|
| 107 |
+
repetition_penalty=1.2)
|
| 108 |
result = tokenizer.decode(preds[0].cpu().numpy())
|
| 109 |
result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
|
| 110 |
print(f'GENERATING MONSTER COMPLETE')
|
|
|
|
| 302 |
placeholder_image = Image.new(mode="RGB", size=(256, 256))
|
| 303 |
return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
|
| 304 |
text = gen_monster_text(name)
|
| 305 |
+
description = parse_monster_description(name, text)
|
| 306 |
pil = gen_image(description)
|
| 307 |
image_data = pil_to_base64(pil)
|
| 308 |
card_html = format_monster_card(text, image_data)
|