File size: 16,062 Bytes
1a5a832
 
 
af7014e
1a5a832
 
4f93f43
3698240
d61dc23
9996b89
afcc6ac
3698240
afcc6ac
af7014e
d8e142e
eaf3553
afcc6ac
d8e142e
4f93f43
afcc6ac
 
3698240
 
3cbbfb9
 
 
3698240
 
afcc6ac
872279f
 
 
afcc6ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3698240
c63db98
d8e142e
 
 
 
 
 
 
 
c86a2d5
37d1dcb
c86a2d5
37d1dcb
 
87e24e1
 
c86a2d5
 
d8e142e
3698240
 
 
d8e142e
c63db98
 
d8e142e
1a5a832
c86a2d5
d8e142e
 
 
 
 
 
 
 
 
 
872279f
d8e142e
 
 
 
3698240
 
c63db98
 
 
 
 
 
99beedf
1a5a832
99beedf
 
1a5a832
 
 
 
 
 
c63db98
1a5a832
c63db98
 
 
 
 
6ae9204
99beedf
c63db98
9d40e2a
1a5a832
d1c0fdc
c63db98
 
 
1a5a832
 
 
 
 
 
 
af7014e
 
 
 
 
 
 
1a5a832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af7014e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a5a832
 
af7014e
 
1a5a832
 
af7014e
 
1a5a832
 
af7014e
 
1a5a832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3953183
 
1a5a832
 
 
 
 
 
 
 
af7014e
1a5a832
 
 
 
 
 
af7014e
 
 
 
 
 
 
3953183
 
af7014e
 
 
 
 
1a5a832
 
 
29350b7
9996b89
62a8602
1a5a832
af7014e
1a5a832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af7014e
 
 
9996b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af7014e
 
3953183
af7014e
d1c0fdc
af7014e
9996b89
 
af7014e
 
 
 
4395834
af7014e
 
29350b7
 
 
1a5a832
99beedf
1a5a832
 
af7014e
 
 
3953183
 
 
 
 
 
 
 
b2773c1
29350b7
7a247ef
 
 
29350b7
 
3953183
4395834
4f93f43
d8e142e
 
7a247ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
import base64
import pathlib
import re
import time
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageChops, ImageDraw
from fastai.callback.core import Callback
from fastai.learner import *
from fastai.torch_core import TitledStr
from html2image import Html2Image
# from min_dalle import MinDalle
from torch import tensor, Tensor, float16, float32
from torch.distributions import Transform
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

# These utility functions need to be in main (or otherwise where created) because fastai loads from that module, see:
# https://docs.fast.ai/learner.html#load_learner
from transformers import GPT2TokenizerFast

import os
# Hugging Face access token read from the environment (None if unset). It is passed as
# use_auth_token when downloading the Stable Diffusion scheduler and pipeline further down.
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

# update requirements.txt with (developer's local venv path):
# C:\Users\Grant\PycharmProjects\test_space\venv\Scripts\pip3.exe freeze > requirements.txt

# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources

# Base GPT-2 tokenizer, created once at import time via from_pretrained (may download on first run).
# Shared by tokenize() and gen_monster_text(); presumably matches the vocabulary of the model
# inside export.pkl — TODO confirm.
pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)


def tokenize(text):
    """Return a tensor of GPT-2 token ids for *text*, using the module-level tokenizer."""
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(token_ids)


class TransformersTokenizer(Transform):
    """Wrap a Hugging Face tokenizer: encode text to token-id tensors and decode ids back to text.

    This class must live in this module because load_learner('export.pkl') unpickles the learner,
    and unpickling resolves this class by name.
    """

    # NOTE(review): `Transform` is imported from torch.distributions in this file, while the
    # encodes/decodes naming follows fastai/fastcore's Transform. Confirm which base the exported
    # learner expects before changing the import.

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        """Return *x* unchanged if it is already a tensor; otherwise tokenize it into a tensor of ids.

        Uses the wrapped tokenizer (the same one ``decodes`` uses) rather than the module global.
        """
        if isinstance(x, Tensor):
            return x
        tokens = self.tokenizer.tokenize(x)
        return tensor(self.tokenizer.convert_tokens_to_ids(tokens))

    def decodes(self, x):
        """Decode a tensor of token ids back into a fastai TitledStr."""
        return TitledStr(self.tokenizer.decode(x.cpu().numpy()))


class DropOutput(Callback):
    """fastai callback that replaces the learner's prediction with its first element.

    Must be defined in this module so load_learner can unpickle the exported learner.
    """

    def after_pred(self):
        # Presumably the wrapped model returns a tuple-like output whose first item is the
        # logits — TODO confirm against the model in export.pkl.
        raw_prediction = self.pred
        self.learn.pred = raw_prediction[0]


# Image model: initialized only once, at import time, and shared by every gen_image() call.
# Takes about 2 minutes (126 seconds) to generate an image in Hugging Face Spaces on CPU
# (presumably measured with the earlier min-dalle model below — TODO re-measure).
# NOTE as of 2022-11-13 min-dalle is broken, so a Stable Diffusion model is used for images instead.
# model = MinDalle(
#     models_root='./pretrained',
#     dtype=float32,
#     device='cpu',
#     is_mega=True,
#     is_reusable=True
# )
# Download the Stable Diffusion v1.5 pipeline, but overwrite its scheduler with Euler Ancestral.
# Consider DPMSolverMultistepScheduler once added to diffusers
# NOTE(review): redundant — EulerAncestralDiscreteScheduler is already imported at the top of the file.
from diffusers import EulerAncestralDiscreteScheduler
scheduler = EulerAncestralDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler",
                                                        use_auth_token=AUTH_TOKEN)
# "stable_diffusion_mega" is a diffusers community pipeline; gen_image() relies on its text2img() method.
# float32 weights, presumably because inference runs on CPU here (the .to("cuda") call below is commented out).
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega",
                                             torch_dtype=torch.float32,
                                             scheduler=scheduler, use_auth_token=AUTH_TOKEN)
# pipeline.enable_attention_slicing()
# pipeline.to("cuda")


def gen_image(prompt, *, width=256, height=256, num_inference_steps=20):
    """Generate a fantasy-style illustration for *prompt* with the shared Stable Diffusion pipeline.

    Args:
        prompt: Subject description; a fixed art-style suffix is appended before generation.
        width: Output width in pixels (default matches the previously hard-coded value).
        height: Output height in pixels (default matches the previously hard-coded value).
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        The first image of the pipeline result. run() passes it to pil_to_base64, so it is
        expected to be a PIL image.
    """
    styled_prompt = f"{prompt}, fantasy painting by Greg Rutkowski"
    print(f'RUNNING gen_image with prompt: {styled_prompt}')
    # Hugging Face Spaces seemed to run out of memory when gradients were not disabled
    # (see https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py), so disable them explicitly.
    with torch.no_grad():
        images = pipeline.text2img(styled_prompt, width=width, height=height,
                                   num_inference_steps=num_inference_steps).images
    print('COMPLETED GENERATION')
    return images[0]


# Set True to run text generation on a CUDA GPU; gen_monster_text() also reads this flag
# to decide whether to move its input ids with .cuda().
gpu = False
# init only once: the text-generation learner, loaded at import time from export.pkl.
# load_learner unpickles the file, so TransformersTokenizer and DropOutput must be defined in this
# module (see the comment above the transformers import). Only load a trusted export.pkl.
learner = load_learner('export.pkl',
                       cpu=not gpu)  # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`


def parse_monster_description(name, text):
    """Return the first sentence of the "Description: ..." line in *text*, for use as an image prompt.

    Args:
        name: Monster name, used for the fallback subject.
        text: Generated monster stat-block text.

    Returns:
        The first sentence (text before the first '.') with surrounding whitespace removed.
        Falls back to "<name> is a monster." when there is no description line or that sentence
        is blank, so the image prompt always has a subject.
    """
    match = re.search(r"Description: (.*)", text)
    first_sentence = match.group(1).split('.')[0].strip() if match else ''
    if not first_sentence:
        return f"{name} is a monster."
    print(first_sentence)
    return first_sentence


def gen_monster_text(name):
    """Generate monster stat-block text for *name* with the loaded learner's model.

    The prompt is "Name: <name>" followed by CRLF. Generation samples with 5 beams up to 512
    tokens. Only the text before the first '###' is kept (presumably the entry separator from
    the training data — TODO confirm). Literal "\\r\\n" sequences become newlines, and remaining
    carriage returns (real or literal "\\r") are removed.

    Returns:
        The cleaned text, which is also printed for logging.
    """
    prompt = f"Name: {name}\r\n"
    print(f'GENERATING MONSTER TEXT with prompt: {prompt}')
    input_ids = tensor(tokenizer.encode(prompt))[None]  # add a batch dimension
    if gpu:
        input_ids = input_ids.cuda()  # requires a CUDA-enabled torch build
    generated = learner.model.generate(input_ids, max_length=512, num_beams=5, temperature=1.5, do_sample=True,
                                       repetition_penalty=1.2)
    decoded = tokenizer.decode(generated[0].cpu().numpy())
    first_entry = decoded.split('###')[0]
    cleaned = first_entry.replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
    print('GENERATING MONSTER COMPLETE')
    print(cleaned)
    return cleaned


def extract_text_for_header(text, header):
    """Return the rest of the line after the first "<header>: " in *text*, or '' if there is none.

    *header* is matched literally: regex metacharacters such as '(' or '.' are escaped.
    """
    match = re.search(rf"{re.escape(header)}: (.*)", text)
    return match.group(1) if match else ''


def remove_section(html, html_class):
    """Remove the ``<li class="<html_class>" ...>...li>`` element from *html*.

    The first matching element is located (non-greedy up to the next "li>"), and every occurrence
    of exactly that markup is removed. When no such element exists, *html* is returned unchanged,
    so callers can always reassign the result.
    """
    match = re.search(rf'<li class="{re.escape(html_class)}"[\w\W]*?li>', html)
    if match is None:
        return html
    return html.replace(match.group(0), '')


# Always-present card fields, in substitution order: (stat-block header, template placeholder).
_CARD_FIELDS = (
    ('Name', '{name}'),
    ('Type', '{monster_type}'),
    ('Armor Class', '{armor_class}'),
    ('Hit Points', '{hit_points}'),
    ('Speed', '{speed}'),
    ('STR', '{str_stat}'),
    ('DEX', '{dex_stat}'),
    ('CON', '{con_stat}'),
    ('INT', '{int_stat}'),
    ('WIS', '{wis_stat}'),
    ('CHA', '{cha_stat}'),
)

# Optional card fields, in substitution order:
# (stat-block header, template placeholder, class of the <li> removed when the value is empty).
_OPTIONAL_CARD_FIELDS = (
    ('Saving Throws', '{saving_throws}', 'monster-saves'),
    ('Skills', '{skills}', 'monster-skills'),
    ('Damage Vulnerabilities', '{damage_vulnerabilities}', 'monster-vulnerabilities'),
    ('Damage Resistances', '{damage_resistances}', 'monster-resistances'),
    ('Damage Immunities', '{damage_immunities}', 'monster-immunities'),
    ('Condition Immunities', '{condition_immunities}', 'monster-conditions'),
    ('Senses', '{senses}', 'monster-senses'),
    ('Languages', '{languages}', 'monster-languages'),
    ('Challenge', '{challenge}', 'monster-challenge'),
)


def _render_card_entries(monster_text, section, stop_word, css_class):
    """Render the "Label: detail" lines that follow a "<section>:" header line as HTML entries.

    The section regex captures everything to the end of the text, so rendering stops at the first
    label containing *stop_word* (the start of the other section). Lines without ':' are skipped,
    as is a repeated "<section>" label. If the captured text has no ':' at all, it is wrapped
    verbatim in a single entry. All model text is HTML-escaped.
    """
    from html import escape

    match = re.search(rf"{re.escape(section)}:\n([\w\W]*)", monster_text)
    body = match.group(1) if match else ''
    if ':' not in body:
        return f'<div class="{css_class}"><p>{escape(body)}</p></div>'
    entries = []
    for line in body.split('\n'):
        label, sep, detail = line.partition(':')
        if not sep or label == section:
            continue
        if stop_word in label:
            break
        entries.append(f'<div class="{css_class}"><p><span class="name">{escape(label)}</span> '
                       f'<span class="detail">{escape(detail)}</span></p></div>')
    return ''.join(entries)


def format_monster_card(monster_text, image_data):
    """Fill the monster card HTML template with fields parsed from *monster_text*.

    Reads ``monsterMakerTemplate.html`` (UTF-8) from the working directory and substitutes its
    ``{placeholder}`` tokens. For optional fields whose value is empty, the corresponding ``<li>``
    section is removed. Model-generated text is HTML-escaped before insertion, so stray ``<`` or
    ``&`` cannot break the markup or inject tags into the page that html_to_png renders.

    Args:
        monster_text: Generated stat block with "Header: value" lines plus "Passives:" and
            "Actions:" sections.
        image_data: base64-encoded PNG bytes (embedded as a data: URI), or any other value,
            inserted as ``str(image_data)`` (presumably a URL).

    Returns:
        The completed card HTML.
    """
    from html import escape

    print('FORMATTING MONSTER TEXT')
    # see giffyglyph's monster maker https://giffyglyph.com/monstermaker/app/
    # Different Formatting style examples and some json export formats
    card = pathlib.Path('monsterMakerTemplate.html').read_text(encoding='utf-8')
    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    card = card.replace('{image_data}', image_src)

    def field(header):
        # Escape model output so it is inserted as text, not markup.
        return escape(extract_text_for_header(monster_text, header))

    for header, placeholder in _CARD_FIELDS:
        card = card.replace(placeholder, field(header))

    for header, placeholder, section_class in _OPTIONAL_CARD_FIELDS:
        value = field(header)
        card = card.replace(placeholder, value)
        if not value:
            stripped = remove_section(card, section_class)
            # Tolerate a remove_section that returns None when the <li> is absent.
            if stripped is not None:
                card = stripped

    card = card.replace('{description}', field('Description'))
    card = card.replace('{passives}', _render_card_entries(monster_text, 'Passives', 'Action', 'monster-trait'))
    card = card.replace('{actions}', _render_card_entries(monster_text, 'Actions', 'Passive', 'monster-action'))

    # TODO: Legendary actions, reactions, make column count for format an option (1 or 2 column layout)

    # Italicise attack keywords. Order is significant: "Ranged Weapon Attack:" also matches inside the
    # already-wrapped "Melee or Ranged Weapon Attack:", yielding nested (visually harmless) <i> tags.
    for phrase in ('Melee or Ranged Weapon Attack:', 'Melee Weapon Attack:', 'Ranged Weapon Attack:', 'Hit:'):
        card = card.replace(phrase, f'<i>{phrase}</i>')

    print('FORMATTING MONSTER TEXT COMPLETE')
    return card


def pil_to_base64(image):
    """Encode *image* as PNG and return the base64 encoding of that file as ``bytes`` (not ``str``)."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded


# Shared Html2Image renderer used by html_to_png(); output_path is the directory its screenshots are saved to.
hti = Html2Image(output_path='rendered_cards')


def trim(im, border):
    """Crop *im* to the bounding box of pixels that differ from the solid colour *border*.

    Returns *im* unchanged (instead of None) when every pixel equals *border*, so callers always
    get an image back.
    """
    background = Image.new(im.mode, im.size, border)
    bbox = ImageChops.difference(im, background).getbbox()
    return im.crop(bbox) if bbox else im


def crop_background(image):
    """Flood-fill the background from the top-right pixel with white, then trim the white border.

    The flood fill (threshold 50) modifies *image* in place before trimming.
    """
    white = (255, 255, 255)
    top_right = (image.size[0] - 1, 0)
    ImageDraw.floodfill(image, top_right, white, thresh=50)
    return trim(image, white)


def html_to_png(html):
    """Screenshot *html* (styled with monstermaker.css) and return it as a background-cropped RGB image."""
    print('CONVERTING HTML CARD TO PNG IMAGE')
    # NOTE(review): every call writes to the same "test.png"; overlapping requests could race — confirm.
    screenshot_path = hti.screenshot(html_str=html, css_file="monstermaker.css", save_as="test.png",
                                     size=(800, 1440))[0]
    print('OPENING IMAGE FROM FILE')
    with Image.open(screenshot_path) as screenshot:
        rgb_image = screenshot.convert("RGB")
    print('CROPPING BACKGROUND')
    cropped = crop_background(rgb_image)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped


def run(name: str) -> 'tuple[Image.Image, str, Image.Image, str]':
    """Generate a complete monster for *name*: stat text, illustration and rendered card.

    Returns:
        ``(card_image, monster_text, monster_image, card_html)``, the four Gradio outputs in order.
        When *name* is empty or whitespace-only, returns a 256x256 placeholder image (in both
        image slots), an explanatory message and empty HTML, without running the slow generation.
    """
    start = time.perf_counter()
    print(f'BEGINNING RUN FOR {name}')
    if not name or not name.strip():
        placeholder_image = Image.new(mode="RGB", size=(256, 256))
        return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
    text = gen_monster_text(name)
    description = parse_monster_description(name, text)
    pil = gen_image(description)
    image_data = pil_to_base64(pil)
    card_html = format_monster_card(text, image_data)
    card_image = html_to_png(card_html)
    elapsed = time.perf_counter() - start
    print(f'RUN COMPLETED IN {int(elapsed)} seconds')
    return card_image, text, pil, card_html


# Gradio UI: one text input and four outputs in the same order as run()'s return value.
# The outputs are pre-populated with example files under examples/ so the page is not empty
# before the first (slow) run.
app_description = (
    """
    # Create your own D&D monster!
    Enter a name, click Submit, and wait for about 4 minutes to see the result.
    """).strip()
input_box = gr.Textbox(label="Enter a monster name", placeholder="Jabberwock")
output_monster_card = gr.Image(label="Monster Card", type='pil', value="examples/jabberwock_card.png")
output_text_box = gr.Textbox(label="Monster Text", value=pathlib.Path("examples/jabberwock.txt").read_text('utf-8'))
output_monster_image = gr.Image(label="Monster Image", type='pil', value="examples/jabberwock.png")
output_monster_html = gr.HTML(label="Monster HTML", visible=False, show_label=False)
iface = gr.Interface(title="MonsterGen", theme="default", description=app_description, fn=run, inputs=[input_box],
                     outputs=[output_monster_card, output_text_box, output_monster_image, output_monster_html])
iface.launch()
# TODO: Add examples, larger language model?, document process, log silences, "Passives" => "Traits", log timestamps
# Fine tune dalle-mini? https://blog.paperspace.com/dalle-mini/
# API works, assuming query takes no longer than 30 seconds (504 gateway timeout)
# Looks like API page improvements are in progress: https://github.com/gradio-app/gradio/issues/1325
# Example code below:
# import requests
# r = requests.post(url='https://hf.space/embed/gstaff/test_space/+/api/predict', json={"data": [""]},
#                   timeout=None)
# print(r.json())

# Looks like Huggingface uses the queue push api, then polls for status:
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/push/", {
#   "headers": {
#     "accept": "*/*",
#     "accept-language": "en-US,en;q=0.9",
#     "content-type": "application/json",
#     "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"",
#     "sec-ch-ua-mobile": "?0",
#     "sec-ch-ua-platform": "\"Windows\"",
#     "sec-fetch-dest": "empty",
#     "sec-fetch-mode": "cors",
#     "sec-fetch-site": "same-origin"
#   },
#   "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light",
#   "referrerPolicy": "strict-origin-when-cross-origin",
#   "body": "{\"fn_index\":0,\"data\":[\"Jabberwock\"],\"action\":\"predict\",\"session_hash\":\"v9ehgfho3p\"}",
#   "method": "POST",
#   "mode": "cors",
#   "credentials": "omit"
# });

# fetch("https://hf.space/embed/gstaff/test_space/api/queue/status/", {
#   "headers": {
#     "accept": "*/*",
#     "accept-language": "en-US,en;q=0.9",
#     "content-type": "application/json",
#     "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"",
#     "sec-ch-ua-mobile": "?0",
#     "sec-ch-ua-platform": "\"Windows\"",
#     "sec-fetch-dest": "empty",
#     "sec-fetch-mode": "cors",
#     "sec-fetch-site": "same-origin"
#   },
#   "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light",
#   "referrerPolicy": "strict-origin-when-cross-origin",
#   "body": "{\"hash\":\"09f5369a7a414169aa48948bad5fd93d\"}",
#   "method": "POST",
#   "mode": "cors",
#   "credentials": "omit"
# });