Spaces:
Runtime error
Runtime error
File size: 7,236 Bytes
1670782 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import os
import re
import numpy as np
import vtracer
import svgpathtools
import cairosvg
import io
import cv2
from lxml import etree
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import cdist
from python_tsp.heuristics import solve_tsp_local_search
from fast_tsp import find_tour
from svgpathtools import Path
from tqdm import tqdm
def parse_transform(transform_str):
    """Convert an SVG ``transform`` attribute into a 3x3 homogeneous matrix.

    The transform list is processed in document order and each entry is
    right-multiplied onto the running matrix, so ``"translate(10,0) scale(2)"``
    yields ``T @ S`` (the point is scaled first, then translated), which is
    how SVG defines a transform list.

    Supported functions: ``matrix``, ``translate``, ``scale``, ``rotate``
    (optionally about a centre), ``skewX`` and ``skewY``.  Arguments may be
    separated by commas and/or whitespace.  Entries that carry too few
    arguments to be interpreted are skipped rather than raising.

    Args:
        transform_str: The raw attribute value, or ``None``/empty.

    Returns:
        A ``(3, 3)`` float ``numpy.ndarray``; the identity when
        ``transform_str`` is falsy or contains no recognised function.
    """
    matrix = np.eye(3)
    if not transform_str:
        return matrix

    number = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
    call = r"(matrix|translate|scale|rotate|skewX|skewY)\s*\(([^)]*)\)"

    for name, raw_args in re.findall(call, transform_str):
        args = [float(v) for v in re.findall(number, raw_args)]
        if not args:
            continue

        if name == "matrix":
            if len(args) != 6:
                continue
            a, b, c, d, e, f = args
            m = np.array([[a, c, e], [b, d, f], [0, 0, 1]])
        elif name == "translate":
            # A single value means a translation along x only.
            tx, ty = args if len(args) == 2 else (args[0], 0)
            m = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        elif name == "scale":
            # A single value means uniform scaling.
            sx, sy = args if len(args) == 2 else (args[0], args[0])
            m = np.array([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        elif name == "rotate":
            angle = np.radians(args[0])
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            m = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
            if len(args) == 3:
                # rotate(a, cx, cy) == translate(cx, cy) rotate(a) translate(-cx, -cy)
                cx, cy = args[1], args[2]
                to_centre = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
                from_centre = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
                m = to_centre @ m @ from_centre
        elif name == "skewX":
            m = np.array([[1, np.tan(np.radians(args[0])), 0], [0, 1, 0], [0, 0, 1]])
        else:  # skewY
            m = np.array([[1, 0, 0], [np.tan(np.radians(args[0])), 1, 0], [0, 0, 1]])

        # Right-multiply: later entries in the list act on the point first.
        matrix = matrix @ m

    return matrix
def get_global_transform(element):
    """Accumulate the ``transform`` attributes from *element* up to the root.

    Starting at *element* and following ``getparent()`` until it returns
    ``None``, every non-empty ``transform`` attribute is parsed and
    left-multiplied onto the running matrix, so ancestors end up to the left
    of their descendants in the product.

    Returns:
        A 3x3 numpy matrix; the identity when no node on the chain carries a
        transform (or when *element* is ``None``).
    """
    accumulated = np.eye(3)
    node = element
    while node is not None:
        local = node.get("transform")
        if local:
            # The ancestor's matrix goes on the left of what was gathered so far.
            accumulated = parse_transform(local) @ accumulated
        node = node.getparent()
    return accumulated
def get_transformed_paths_and_coords(svg_content, mode):
    """Parse an SVG and collect its ``<path>`` elements with global transforms.

    Args:
        svg_content: When ``mode == 'file'`` this is treated as SVG markup
            text: it is UTF-8 encoded and wrapped in a ``BytesIO``.  For any
            other mode it is handed to ``etree.parse`` unchanged.
            NOTE(review): the mode name reads backwards relative to that
            behaviour; confirm against the callers.
        mode: Selects how ``svg_content`` is interpreted (see above).

    Returns:
        A 5-tuple ``(paths_data, coords, width, height, viewBox)``:

        * ``paths_data``: list of dicts with keys ``'path'`` (parsed
          svgpathtools path), ``'transform'`` (3x3 global matrix),
          ``'element'`` (the lxml element) and ``'coord'`` (transformed start
          point as an ``(x, y)`` tuple).
        * ``coords``: ``numpy`` array built from the ``'coord'`` tuples, in
          the same order as ``paths_data``.
        * ``width`` / ``height``: the root element's raw attribute values
          (units are not stripped; ``None`` when absent).
        * ``viewBox``: the root's ``viewBox`` attribute, or
          ``"0 0 <width> <height>"`` synthesised from width/height when the
          attribute is missing and both dimensions are present.

    Path elements without a ``d`` attribute, or whose ``d`` parses to zero
    segments (for example a lone move-to), are skipped because they have no
    start point to sequence.
    """
    parser = etree.XMLParser(remove_blank_text=True)
    if mode == 'file':
        svg_content = io.BytesIO(svg_content.encode('utf-8'))
    tree = etree.parse(svg_content, parser)
    root = tree.getroot()

    width = root.get("width")
    height = root.get("height")
    viewBox = root.get("viewBox")

    if not viewBox and width and height:
        # Remove potential units like 'px' to get clean numbers
        bare_width = re.sub(r'[a-zA-Z%]', '', width)
        bare_height = re.sub(r'[a-zA-Z%]', '', height)
        viewBox = f"0 0 {bare_width} {bare_height}"
        print(f"SVG missing viewBox. Created a default: '{viewBox}'")

    transformed_coords = []
    paths_data = []
    for elem in root.findall(".//{*}path"):
        d_string = elem.get('d')
        if not d_string:
            continue
        path = svgpathtools.parse_path(d_string)
        if len(path) == 0:
            # Nothing drawable: accessing path.start below would fail.
            continue
        transform = get_global_transform(elem)
        start_vec = np.array([[path.start.real], [path.start.imag], [1]])
        transformed_start = transform @ start_vec
        coord = (transformed_start[0, 0], transformed_start[1, 0])
        transformed_coords.append(coord)
        paths_data.append({
            'path': path,
            'transform': transform,
            'element': elem,
            'coord': coord,
        })

    print(f"Extracted {len(transformed_coords)} paths with their elements.")
    return paths_data, np.array(transformed_coords), width, height, viewBox
def transform_path(path, matrix):
    """Apply a 3x3 numpy transform to an svgpathtools Path object.

    Every defining point of each segment is transformed, not only its
    endpoints, so Bezier control points move together with the curve.  The
    input path and its segments are left untouched; new segment objects are
    built for the returned path.

    Args:
        path: Iterable of svgpathtools segments (typically a ``Path``).
        matrix: 3x3 homogeneous transform; only the affine part is used (the
            resulting third coordinate is not divided out).

    Returns:
        A new ``Path`` containing the transformed segments.
    """
    matrix = np.asarray(matrix, dtype=float)

    def apply(point):
        vec = matrix @ np.array([point.real, point.imag, 1.0])
        return complex(vec[0], vec[1])

    new_segments = []
    for seg in path:
        if hasattr(seg, "bpoints"):
            # Line / QuadraticBezier / CubicBezier: the constructor takes the
            # control points in the same order bpoints() returns them.
            new_segments.append(type(seg)(*(apply(p) for p in seg.bpoints())))
        else:
            # Segments without Bezier control points (e.g. Arc) need their
            # radii/rotation recomputed; defer to svgpathtools' own routine.
            new_segments.append(svgpathtools.path.transform(seg, matrix))
    return Path(*new_segments)
def sequence_strokes(paths_data, coords, proximity_threshold=40):
    """Group strokes into spatial clusters and order the clusters by a tour.

    Strokes are clustered with Ward linkage on ``coords`` and cut at
    ``proximity_threshold`` (``criterion='distance'``).  A tour over the
    cluster centroids, computed by ``find_tour`` on the integer-truncated
    centroid distance matrix, decides the order in which clusters are
    emitted.  Inside a cluster the strokes keep their input order.

    Args:
        paths_data: Sequence of stroke records, parallel to ``coords``.
        coords: One ``(x, y)`` position per stroke.
        proximity_threshold: Cut height passed to ``fcluster``.

    Returns:
        ``paths_data`` itself when there are fewer than two strokes or only
        one cluster is found; otherwise a new list with the strokes
        regrouped cluster by cluster.
    """
    if len(coords) < 2:
        print("Fewer than 2 strokes, no sequencing needed.")
        return paths_data

    print("Clustering strokes by proximity...")
    tree = linkage(coords, method='ward')
    labels = fcluster(tree, t=proximity_threshold, criterion='distance')
    num_clusters = len(set(labels))
    print(f"{num_clusters} clusters detected.")

    if num_clusters <= 1:
        print("All strokes are in a single cluster, no reordering needed.")
        return paths_data

    # The grouping below assumes labels run from 1 to num_clusters.
    cluster_labels = range(1, num_clusters + 1)
    members = {label: [] for label in cluster_labels}
    positions = {label: [] for label in cluster_labels}
    for index, label in enumerate(labels):
        members[label].append(paths_data[index])
        positions[label].append(coords[index])

    centroids = [np.mean(positions[label], axis=0) for label in cluster_labels]

    print("Solving TSP for optimal cluster drawing order...")
    pairwise = cdist(centroids, centroids)
    tour = find_tour(pairwise.astype(np.int32).tolist())

    ordered = []
    for position in tour:
        # Tour entries are 0-based centroid indices; labels are 1-based.
        ordered.extend(members[position + 1])

    print("Final stroke sequence created.")
    return ordered
def serialize_paths(paths_data):
    """Reduce path records to plain dicts with ``d``, ``transform`` and ``fill``.

    The fill colour is taken from the element's ``fill`` attribute when it is
    set and non-empty, otherwise from a ``fill:`` declaration inside its
    ``style`` attribute, and falls back to ``"#000000"``.

    Args:
        paths_data: Iterable of dicts whose ``'element'`` value exposes
            ``get(name, default)`` for attribute lookup.

    Returns:
        A list with one ``{"d", "transform", "fill"}`` dict per input record,
        in input order.
    """
    serialized = []
    for data in paths_data:
        elem = data['element']
        fill = elem.get("fill")
        if not fill:
            # Only consult the inline style when no fill attribute is present.
            fill_match = re.search(r'fill:\s*([^;]+)', elem.get("style", ""))
            fill = fill_match.group(1).strip() if fill_match else "#000000"
        serialized.append({
            "d": elem.get("d"),
            "transform": elem.get("transform"),
            "fill": fill,
        })
    return serialized
def process_svg(svg_content, mode):
    """Split an SVG's paths into layers, fills and details, then order them.

    Each path is ranked by the area of its bounding box.  Paths at or above
    the value found at the 99% position of the sorted areas become
    ``layers``; the remaining paths at or above the median become ``fills``;
    everything else is ``details``.  Fills and details are reordered with
    ``sequence_strokes``; layers keep document order.

    Args:
        svg_content: Passed through to ``get_transformed_paths_and_coords``.
        mode: Passed through to ``get_transformed_paths_and_coords``.

    Returns:
        A dict with keys ``"width"``, ``"height"``, ``"viewbox"``,
        ``"layers"``, ``"fills"`` and ``"details"``.  The three path lists
        are empty when the SVG yields no paths.
    """
    paths_data, coords, width, height, viewBox = get_transformed_paths_and_coords(svg_content, mode)

    if not paths_data:
        # No paths: there is no median or percentile to index into.
        return {
            "width": width,
            "height": height,
            "viewbox": viewBox,
            "layers": [],
            "fills": [],
            "details": [],
        }

    for data in paths_data:
        # NOTE(review): the bounding box is taken from the untransformed
        # path, so areas ignore each element's transform; confirm intended.
        xmin, xmax, ymin, ymax = data['path'].bbox()
        data['area'] = (xmax - xmin) * (ymax - ymin)

    areas_sorted = sorted(p['area'] for p in paths_data)
    area_threshold = areas_sorted[len(areas_sorted) // 2]
    layer_threshold = areas_sorted[int(len(areas_sorted) * 0.99)]

    fill_strokes = []
    detail_strokes = []
    layer_strokes = []
    for data in paths_data:
        if data['area'] >= layer_threshold:
            layer_strokes.append(data)
        elif data['area'] >= area_threshold:
            fill_strokes.append(data)
        else:
            detail_strokes.append(data)

    ordered_fills = sequence_strokes(fill_strokes, np.array([p['coord'] for p in fill_strokes]))
    ordered_details = sequence_strokes(detail_strokes, np.array([p['coord'] for p in detail_strokes]))

    return {
        "width": width,
        "height": height,
        "viewbox": viewBox,
        "layers": serialize_paths(layer_strokes),
        "fills": serialize_paths(ordered_fills),
        "details": serialize_paths(ordered_details),
    }
|