gstaff commited on
Commit
99beedf
·
1 Parent(s): 7a247ef

Provide default description for image generation if missing, add repetition penalty of 1.2 to text generation

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -86,8 +86,10 @@ learner = load_learner('export.pkl',
86
  cpu=not gpu) # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
87
 
88
 
89
- def parse_monster_description(text):
90
  match = re.search(r"Description: (.*)", text)
 
 
91
  description = match.group(1)
92
  print(description.split('.')[0])
93
  return description.split('.')[0]
@@ -101,7 +103,8 @@ def gen_monster_text(name):
101
  inp = tensor(prompt_ids)[None].cuda() # Use .cuda() for torch GPU
102
  else:
103
  inp = tensor(prompt_ids)[None]
104
- preds = learner.model.generate(inp, max_length=1024, num_beams=5, temperature=1.5, do_sample=True)
 
105
  result = tokenizer.decode(preds[0].cpu().numpy())
106
  result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
107
  print(f'GENERATING MONSTER COMPLETE')
@@ -299,7 +302,7 @@ def run(name: str) -> (Image, str, Image, str):
299
  placeholder_image = Image.new(mode="RGB", size=(256, 256))
300
  return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
301
  text = gen_monster_text(name)
302
- description = parse_monster_description(text)
303
  pil = gen_image(description)
304
  image_data = pil_to_base64(pil)
305
  card_html = format_monster_card(text, image_data)
 
86
  cpu=not gpu) # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
87
 
88
 
89
+ def parse_monster_description(name, text):
90
  match = re.search(r"Description: (.*)", text)
91
+ if not match:
92
+ return f"{name} is a monster."
93
  description = match.group(1)
94
  print(description.split('.')[0])
95
  return description.split('.')[0]
 
103
  inp = tensor(prompt_ids)[None].cuda() # Use .cuda() for torch GPU
104
  else:
105
  inp = tensor(prompt_ids)[None]
106
+ preds = learner.model.generate(inp, max_length=1024, num_beams=5, temperature=1.5, do_sample=True,
107
+ repetition_penalty=1.2)
108
  result = tokenizer.decode(preds[0].cpu().numpy())
109
  result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
110
  print(f'GENERATING MONSTER COMPLETE')
 
302
  placeholder_image = Image.new(mode="RGB", size=(256, 256))
303
  return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
304
  text = gen_monster_text(name)
305
+ description = parse_monster_description(name, text)
306
  pil = gen_image(description)
307
  image_data = pil_to_base64(pil)
308
  card_html = format_monster_card(text, image_data)