File size: 7,987 Bytes
3545cb6
64cb722
3545cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
import re
import tempfile
import threading
from datetime import datetime

import torch
import vtracer
import cairosvg
from PIL import Image
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
import torchvision.transforms as transforms

from model import Generator

def setup_directories():
    """Create the SVG and thumbnail output folders if they do not exist yet.

    Existing folders are left untouched (``exist_ok=True``); a confirmation
    line naming both folders is printed afterwards.
    """
    for folder in (SVG_DIR, THUMBNAIL_DIR):
        os.makedirs(folder, exist_ok=True)
    print(f"Directories '{SVG_DIR}' and '{THUMBNAIL_DIR}' are ready.")

def sanitize_filename(prompt):
    """Return *prompt* with filename-invalid characters removed, cut to 100 chars.

    The characters ``\\ / * ? : " < > |`` are deleted; surrounding whitespace
    is kept.

    NOTE(review): a second ``sanitize_filename`` is defined further down in
    this module and rebinds the name, so this version is not the one the
    request handlers end up calling.
    """
    invalid_chars = re.compile(r'[\\/*?:"<>|]')
    cleaned = invalid_chars.sub("", prompt)
    max_length = 100
    return cleaned[:max_length]

# Generated artwork is stored under the server's current working directory.
SVG_DIR = os.path.join(os.getcwd(), "generated_svgs")    # traced SVG files
THUMBNAIL_DIR = os.path.join(os.getcwd(), "thumbnails")  # 256x256 PNG previews
# Sketch-generator checkpoint, resolved relative to the working directory.
SKETCH_MODEL_WEIGHTS = "checkpoints/netG_A_latest.pth"

class ImageToSvgPipeline:
    """
    End-to-end pipeline from a text prompt to SVG markup.

    Stages: the Rinna Japanese Stable Diffusion model renders the prompt to a
    raster image, the sketch Generator turns that image into a single-channel
    sketch, and vtracer traces the sketch into SVG. Both models are loaded
    once in the constructor and reused by every ``process`` call.
    """

    def __init__(self, sketch_model_path: str):
        """Pick the compute device and load both models.

        Args:
            sketch_model_path: Path to the sketch Generator's saved state dict.

        Raises:
            FileNotFoundError: If ``sketch_model_path`` does not exist.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Serialises model use in `process`: a single pipeline instance (with
        # its scheduler) is shared by every request, and Flask's development
        # server handles requests on multiple threads by default.
        # NOTE(review): diffusers pipelines are presumably not thread-safe.
        self._lock = threading.Lock()

        self._initialize_rinna_model()
        self._initialize_sketch_model(sketch_model_path)

    def _initialize_rinna_model(self):
        """Load rinna/japanese-stable-diffusion with an LMS scheduler onto ``self.device``."""
        print("Loading Rinna Stable Diffusion model...")
        model_id = "rinna/japanese-stable-diffusion"

        self.rinna_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            # Half precision only on GPU; CPU inference stays in float32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        # Swap in an LMS discrete scheduler with a scaled-linear beta schedule.
        self.rinna_pipe.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012,
            beta_schedule="scaled_linear", num_train_timesteps=1000
        )
        # Limit tokenized prompts to 77 tokens (presumably the text encoder's
        # context length — confirm against the model card).
        self.rinna_pipe.tokenizer.model_max_length = 77
        self.rinna_pipe.to(self.device)
        print("Rinna model loaded.")

    def _initialize_sketch_model(self, model_path: str):
        """Build the sketch Generator (RGB in, 1 channel out) and load its weights.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        print(f"Loading Sketch Generator model from {model_path}...")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Sketch model weights not found at: {model_path}")

        self.sketch_model = Generator(input_nc=3, output_nc=1, n_residual_blocks=3)
        self.sketch_model.to(self.device)
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code — only load trusted files (newer torch releases
        # support weights_only=True for state dicts).
        self.sketch_model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.sketch_model.eval()

        self.sketch_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        print("Sketch model loaded.")

    def _generate_image(
        self,
        prompt: str,
        negative_prompt: str,
        steps: int = 30,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
    ) -> Image.Image:
        """Render ``prompt`` to one image with the diffusion pipeline.

        Args:
            prompt: Text to render.
            negative_prompt: Text describing content to steer away from.
            steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance strength.
            width: Output width in pixels.
            height: Output height in pixels.

        Returns:
            The first image produced by the pipeline.
        """
        print(f"Generating image for prompt: '{prompt}'")
        with torch.no_grad():
            image = self.rinna_pipe(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
            ).images[0]
        return image

    def _convert_to_sketch(self, image: Image.Image) -> Image.Image:
        """Run the sketch Generator on ``image`` and return its output as a PIL image."""
        print("Converting image to sketch...")
        with torch.no_grad():
            # Add a batch dimension of 1 for the model, then remove it again.
            input_tensor = self.sketch_transform(image.convert("RGB")).unsqueeze(0).to(self.device)
            output_tensor = self.sketch_model(input_tensor).squeeze(0).cpu()
            # ToPILImage converts float tensors with mul(255) -> uint8, so any
            # value outside [0, 1] would overflow into wrong pixel values.
            # NOTE(review): a no-op if Generator already ends in a [0, 1]
            # activation — confirm against model.Generator.
            output_tensor = output_tensor.clamp(0.0, 1.0)
            sketch_image = transforms.ToPILImage()(output_tensor)
        return sketch_image

    def _extract_svg(self, image: Image.Image) -> str:
        """Trace ``image`` with vtracer and return the SVG markup.

        vtracer works on file paths, so the image and the traced SVG are
        written into a private temporary directory. The directory is removed
        on exit even if saving, tracing or reading fails.
        """
        print("Extracting SVG from sketch...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            png_path = os.path.join(tmp_dir, "sketch.png")
            svg_path = os.path.join(tmp_dir, "sketch.svg")
            image.save(png_path)
            vtracer.convert_image_to_svg_py(png_path, svg_path)
            with open(svg_path, 'r', encoding='utf-8') as f:
                svg_data = f.read()

        print("SVG extraction complete.")
        return svg_data

    def process(self, prompt: str, negative_prompt: str) -> str:
        """Run the full pipeline for ``prompt`` and return SVG markup.

        The two model stages run under ``self._lock`` so concurrent callers do
        not drive the shared models at the same time. Tracing uses only
        per-call temporary files and runs outside the lock.
        """
        with self._lock:
            generated_image = self._generate_image(prompt, negative_prompt)
            sketch_image = self._convert_to_sketch(generated_image)
        return self._extract_svg(sketch_image)

app = Flask(__name__)

# NOTE(review): every origin may call every route here, including the DELETE
# endpoint — consider restricting origins before exposing this server.
CORS(app, resources={r"/*": {"origins": "*"}})

# Create the output folders before serving: /generate writes into them and
# /gallery lists SVG_DIR, so a fresh checkout would otherwise fail.
setup_directories()

# Load the models once at import time; every request reuses this instance.
pipeline = ImageToSvgPipeline(sketch_model_path=SKETCH_MODEL_WEIGHTS)

def sanitize_filename(text):
    """Make ``text`` safe to embed in a filename.

    Removes the characters ``\\ / * ? : " < > |`` and ASCII control
    characters. This includes NUL, which ``open()`` rejects with
    ``ValueError``. Surrounding whitespace is then trimmed.

    >>> sanitize_filename(' a/b:c ')
    'abc'
    """
    text = re.sub(r'[\\/*?:"<>|\x00-\x1f\x7f]', "", text)
    return text.strip()

# Default negative prompt (Japanese): low quality, worst quality, bad hands,
# six fingers, four fingers, deformity, ugly, blurry, watermark, signature, text.
DEFAULT_NEGATIVE_PROMPT = (
    "低品質、最悪の品質、下手な手、指が6本、指が4本、奇形、醜い、"
    "ぼやけている、ぼやけた、ウォーターマーク、署名、テキスト"
)


@app.route('/generate', methods=['POST'])
def generate_svg():
    """Generate an SVG from a JSON body ``{"prompt": "..."}``.

    The SVG is saved under SVG_DIR and a 256x256 PNG thumbnail of it under
    THUMBNAIL_DIR. The SVG markup is returned as ``image/svg+xml``.

    Returns 400 (JSON) when the body is not a JSON object, or when ``prompt``
    is missing, blank, or not a string. Returns 500 (JSON ``{"error": ...}``)
    when generation or saving fails.
    """
    # get_json(silent=True) yields None instead of raising for a missing or
    # malformed JSON body; a non-object body (e.g. a list) is rejected too.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    # Validate before the expensive model run. A non-string prompt (e.g. a
    # JSON list) would otherwise reach the model and fail only afterwards,
    # in sanitize_filename's re.sub.
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400

    try:
        svg_result = pipeline.process(prompt, DEFAULT_NEGATIVE_PROMPT)

        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        safe_prompt = sanitize_filename(prompt)[:50]
        filename = f"{timestamp}_{safe_prompt}.svg"

        svg_path = os.path.join(SVG_DIR, filename)
        with open(svg_path, 'w', encoding='utf-8') as f:
            f.write(svg_result)

        # Same .svg -> .png name mapping that /gallery and DELETE /drawings use.
        thumbnail_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
        cairosvg.svg2png(bytestring=svg_result.encode('utf-8'), write_to=thumbnail_path, output_width=256, output_height=256)

        return svg_result, 200, {'Content-Type': 'image/svg+xml'}
    except Exception as e:
        print(f"An error occurred during generation: {e}")
        return jsonify({"error": str(e)}), 500

@app.route('/gallery', methods=['GET'])
def get_gallery():
    """Return one page of saved drawings, newest first.

    Query params:
        page: 1-based page number (default 1; invalid or < 1 becomes 1).
        limit: page size (default 8; invalid or < 1 becomes 1).

    Response: ``{"drawings": [...], "hasMore": bool}``. Each drawing has a
    ``filename``, a ``thumbnail`` URL, and a ``prompt`` recovered from the
    ``<timestamp>_<prompt>.svg`` filename, with underscores shown as spaces.
    """
    try:
        # type=int falls back to the default for non-numeric input. Clamp to
        # >= 1 because page <= 0 or a negative limit would give negative
        # slice indices that silently return the wrong files.
        page = max(request.args.get('page', 1, type=int), 1)
        limit = max(request.args.get('limit', 8, type=int), 1)

        # Filenames start with a fixed-width timestamp, so reverse name order
        # is newest first. A missing folder just means nothing was generated.
        if os.path.isdir(SVG_DIR):
            svg_files = sorted(
                (f for f in os.listdir(SVG_DIR) if f.endswith('.svg')),
                reverse=True,
            )
        else:
            svg_files = []

        start_index = (page - 1) * limit
        end_index = start_index + limit
        paginated_files = svg_files[start_index:end_index]

        drawings = []
        for filename in paginated_files:
            # DOTALL so prompts saved with embedded newlines are still parsed.
            prompt_match = re.match(r"\d+_(.+)\.svg", filename, re.DOTALL)
            prompt = prompt_match.group(1).replace('_', ' ') if prompt_match else "Prompt not found"
            drawings.append({
                "filename": filename,
                "thumbnail": f"/thumbnails/{filename.replace('.svg', '.png')}",
                "prompt": prompt
            })

        has_more = end_index < len(svg_files)
        return jsonify({"drawings": drawings, "hasMore": has_more})
    except Exception as e:
        print(f"Error fetching gallery: {e}")
        return jsonify({"error": "Failed to fetch gallery"}), 500

@app.route('/svgs/<path:filename>')
def get_svg(filename):
    """Serve a saved SVG from SVG_DIR.

    ``send_from_directory`` refuses paths that would escape the directory.
    """
    svg_root = SVG_DIR
    return send_from_directory(svg_root, filename)

@app.route('/thumbnails/<path:filename>')
def get_thumbnail(filename):
    """Serve a PNG thumbnail from THUMBNAIL_DIR.

    ``send_from_directory`` refuses paths that would escape the directory.
    """
    thumbnails_root = THUMBNAIL_DIR
    return send_from_directory(thumbnails_root, filename)

@app.route('/drawings/<path:filename>', methods=['DELETE'])
def delete_drawing_file(filename):
    """Delete a drawing's SVG and its matching PNG thumbnail.

    Only bare ``*.svg`` names (as listed by /gallery) are accepted. The
    ``path`` converter lets slashes through, and ``os.path.join`` would then
    let ``../`` segments reach files outside SVG_DIR/THUMBNAIL_DIR. Anything
    that is not a plain file name is therefore rejected with 400.
    Deleting a drawing that no longer exists still reports success.
    """
    is_plain_svg_name = (
        filename.endswith('.svg')
        and os.path.basename(filename) == filename
        and '\\' not in filename
    )
    if not is_plain_svg_name:
        return jsonify({"error": "Invalid filename"}), 400

    try:
        svg_path = os.path.join(SVG_DIR, filename)
        # Same .svg -> .png name mapping used when the thumbnail was written.
        thumb_path = os.path.join(THUMBNAIL_DIR, filename.replace('.svg', '.png'))
        if os.path.exists(svg_path): os.remove(svg_path)
        if os.path.exists(thumb_path): os.remove(thumb_path)
        return jsonify({"message": f"Successfully deleted {filename}"})
    except Exception as e:
        print(f"Error deleting file: {e}")
        return jsonify({"error": "Failed to delete file"}), 500

if __name__ == '__main__':
    # Development server, reachable on every network interface at port 5000.
    print("Starting Flask server...")
    app.run(port=5000, host='0.0.0.0')