Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import numpy as np | |
| import vtracer | |
| import svgpathtools | |
| import cairosvg | |
| import io | |
| import cv2 | |
| from lxml import etree | |
| from scipy.cluster.hierarchy import linkage, fcluster | |
| from scipy.spatial.distance import cdist | |
| from python_tsp.heuristics import solve_tsp_local_search | |
| from fast_tsp import find_tour | |
| from svgpathtools import Path | |
| from tqdm import tqdm | |
def parse_transform(transform_str):
    """Convert an SVG ``transform`` attribute into a 3x3 homogeneous matrix.

    The attribute is a list of transform functions. They are composed in the
    order they are written, so that ``"translate(10,20) scale(2)"`` yields
    ``T @ S``: a point is scaled first and translated second, which is how
    SVG renderers interpret the list.

    Supported functions: ``matrix``, ``translate``, ``scale``, ``rotate``
    (optionally about a centre), ``skewX`` and ``skewY``. Arguments may be
    separated by commas and/or whitespace. Unknown functions, or functions
    with an unusable argument list, are skipped.

    Args:
        transform_str: The raw attribute value, or ``None`` / ``""``.

    Returns:
        A 3x3 ``numpy`` array; the identity when there is nothing to apply.
    """
    if not transform_str:
        return np.eye(3)

    number_re = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
    function_re = r"(matrix|translate|scale|rotate|skewX|skewY)\s*\(([^)]*)\)"

    matrix = np.eye(3)
    for name, arg_str in re.findall(function_re, transform_str):
        args = [float(v) for v in re.findall(number_re, arg_str)]
        if not args:
            continue

        if name == "matrix":
            if len(args) != 6:
                continue
            a, b, c, d, e, f = args
            m = np.array([[a, c, e], [b, d, f], [0, 0, 1]])
        elif name == "translate":
            # A missing second argument means "no vertical offset".
            tx, ty = args if len(args) == 2 else (args[0], 0)
            m = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        elif name == "scale":
            # A missing second argument means uniform scaling.
            sx, sy = args if len(args) == 2 else (args[0], args[0])
            m = np.array([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        elif name == "rotate":
            theta = np.radians(args[0])
            cos_t, sin_t = np.cos(theta), np.sin(theta)
            m = np.array([[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]])
            if len(args) == 3:
                # rotate(a, cx, cy) == translate(cx, cy) rotate(a) translate(-cx, -cy)
                cx, cy = args[1], args[2]
                to_center = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
                from_center = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
                m = to_center @ m @ from_center
        elif name == "skewX":
            m = np.array([[1, np.tan(np.radians(args[0])), 0], [0, 1, 0], [0, 0, 1]])
        else:  # skewY
            m = np.array([[1, 0, 0], [np.tan(np.radians(args[0])), 1, 0], [0, 0, 1]])

        # Right-multiply so that the first function in the list ends up
        # outermost (applied last to a point).
        matrix = matrix @ m
    return matrix
def get_global_transform(element):
    """Return the combined 3x3 transform of ``element`` and all its ancestors.

    Starting at ``element`` and walking up through ``getparent()`` until it
    returns ``None``, each node's ``transform`` attribute (when present and
    non-empty) is parsed and multiplied onto the left of the running matrix,
    so ancestor transforms end up outermost.
    """
    accumulated = np.eye(3)
    node = element
    while node is not None:
        transform_attr = node.get("transform")
        if transform_attr:
            accumulated = parse_transform(transform_attr) @ accumulated
        node = node.getparent()
    return accumulated
def get_transformed_paths_and_coords(svg_content, mode):
    """
    Parse an SVG and collect each ``<path>`` together with its global transform.

    When ``mode`` is ``'file'`` the content is taken to be a string and is
    wrapped in an in-memory UTF-8 byte stream; for any other mode it is handed
    to ``etree.parse`` unchanged.

    Returns a 5-tuple ``(paths_data, coords, width, height, viewBox)``:

    * ``paths_data`` - one dict per path with a non-empty ``d`` attribute,
      holding the parsed ``'path'``, its ``'transform'`` matrix, the source
      ``'element'`` and the transformed start point ``'coord'``.
    * ``coords`` - the same start points stacked into a numpy array.
    * ``width`` / ``height`` - the root element's raw attribute values.
    * ``viewBox`` - the root's ``viewBox``, or ``"0 0 <w> <h>"`` derived from
      width/height (letters and ``%`` stripped) when it is missing.
    """
    source = io.BytesIO(svg_content.encode('utf-8')) if mode == 'file' else svg_content
    root = etree.parse(source, etree.XMLParser(remove_blank_text=True)).getroot()

    raw_width = root.get("width")
    raw_height = root.get("height")
    viewBox = root.get("viewBox")

    if not viewBox and raw_width and raw_height:
        # Drop unit suffixes such as 'px' so only the numeric text remains.
        unit_chars = r'[a-zA-Z%]'
        bare_width = re.sub(unit_chars, '', raw_width)
        bare_height = re.sub(unit_chars, '', raw_height)
        viewBox = f"0 0 {bare_width} {bare_height}"
        print(f"SVG missing viewBox. Created a default: '{viewBox}'")

    paths_data = []
    for elem in root.findall(".//{*}path"):
        d_string = elem.get('d')
        if not d_string:
            continue

        path = svgpathtools.parse_path(d_string)
        transform = get_global_transform(elem)

        # Map the path's first point through the accumulated transform.
        origin = transform @ np.array([[path.start.real], [path.start.imag], [1]])
        coord = (origin[0, 0], origin[1, 0])

        paths_data.append({
            'path': path,
            'transform': transform,
            'element': elem,
            'coord': coord,
        })

    transformed_coords = [entry['coord'] for entry in paths_data]
    print(f"Extracted {len(transformed_coords)} paths with their elements.")
    return paths_data, np.array(transformed_coords), raw_width, raw_height, viewBox
def transform_path(path, matrix):
    """Apply a 3x3 numpy transform to an svgpathtools Path object.

    Every segment is copied before it is modified, so the ``path`` passed in
    is left untouched. All point attributes a segment exposes are mapped
    through ``matrix``: ``start`` and ``end`` for every segment, plus
    ``control`` / ``control1`` / ``control2`` where the segment has them, so
    curves keep their shape relative to their endpoints.

    NOTE(review): segments that carry non-point geometry (e.g. an arc's
    radius and rotation) only get their endpoints moved; confirm whether arcs
    under scaling or rotation need dedicated handling.

    Args:
        path: Iterable of svgpathtools segments (typically a ``Path``).
        matrix: 3x3 homogeneous transform.

    Returns:
        A new ``Path`` built from the transformed segment copies.
    """
    import copy

    point_attrs = ("start", "end", "control", "control1", "control2")

    def apply(point):
        # Points are complex numbers: real part is x, imaginary part is y.
        vec = matrix @ np.array([point.real, point.imag, 1.0])
        return complex(vec[0], vec[1])

    new_segments = []
    for seg in path:
        moved = copy.deepcopy(seg)
        for attr in point_attrs:
            if hasattr(moved, attr):
                setattr(moved, attr, apply(getattr(seg, attr)))
        new_segments.append(moved)
    return Path(*new_segments)
def sequence_strokes(paths_data, coords, proximity_threshold=40):
    """
    Group strokes into proximity clusters and order the clusters with a TSP tour.

    Strokes are clustered on ``coords`` using Ward linkage, cut at
    ``proximity_threshold``. When there are fewer than two strokes, or only a
    single cluster results, ``paths_data`` is returned unchanged. Otherwise
    the cluster centroids are fed to ``find_tour`` (distances truncated to
    integers) and the strokes are emitted cluster by cluster in tour order,
    keeping their original relative order inside each cluster.
    """
    if len(coords) < 2:
        print("Fewer than 2 strokes, no sequencing needed.")
        return paths_data

    print("Clustering strokes by proximity...")
    labels = fcluster(
        linkage(coords, method='ward'),
        t=proximity_threshold,
        criterion='distance',
    )
    num_clusters = len(set(labels))
    print(f"{num_clusters} clusters detected.")

    if num_clusters <= 1:
        print("All strokes are in a single cluster, no reordering needed.")
        return paths_data

    # Cluster labels are 1-based; keep one bucket per label.
    cluster_ids = range(1, num_clusters + 1)
    grouped_paths = {cid: [] for cid in cluster_ids}
    grouped_coords = {cid: [] for cid in cluster_ids}
    for idx, label in enumerate(labels):
        grouped_paths[label].append(paths_data[idx])
        grouped_coords[label].append(coords[idx])

    centroids = [np.mean(grouped_coords[cid], axis=0) for cid in cluster_ids]

    print("Solving TSP for optimal cluster drawing order...")
    int_distances = cdist(centroids, centroids).astype(np.int32).tolist()
    tour = find_tour(int_distances)

    # Tour entries are 0-based centroid indices; shift back to 1-based labels.
    final_sequence = [
        stroke
        for centroid_idx in tour
        for stroke in grouped_paths[centroid_idx + 1]
    ]
    print("Final stroke sequence created.")
    return final_sequence
def serialize_paths(paths_data):
    """Turn path records into plain dicts holding ``d``, ``transform`` and ``fill``.

    For each record the ``'element'`` entry is queried with ``.get``. The fill
    colour is resolved in this order: the element's ``fill`` attribute, then a
    ``fill`` declaration inside its ``style`` attribute, then ``"#000000"``.

    The style lookup only accepts ``fill`` as a whole property name (at the
    start of the string or right after a ``;``) and tolerates whitespace
    around the colon, so properties whose names merely end in ``fill`` are
    not mistaken for it.

    Args:
        paths_data: Iterable of dicts, each with an ``'element'`` entry.

    Returns:
        A list of ``{"d", "transform", "fill"}`` dicts in input order.
    """
    fill_re = re.compile(r'(?:^|;)\s*fill\s*:\s*([^;]+)')

    serialized = []
    for data in paths_data:
        elem = data['element']
        fill_match = fill_re.search(elem.get("style", ""))
        style_fill = fill_match.group(1).strip() if fill_match else None
        serialized.append({
            "d": elem.get("d"),
            "transform": elem.get("transform"),
            "fill": elem.get("fill") or style_fill or "#000000",
        })
    return serialized
def process_svg(svg_content, mode):
    """Split an SVG's paths into layers, fills and details, and order them.

    Each path's bounding-box area is stored on its record under ``'area'``.
    Paths are then bucketed by area:

    * ``layers``  - area at or above the value found at the 99% position of
      the sorted areas (the largest shapes);
    * ``fills``   - area at or above the median but below the layer cut-off;
    * ``details`` - everything smaller.

    Fills and details are passed through ``sequence_strokes``; layers keep
    document order. An SVG with no usable paths yields three empty lists
    instead of failing on the threshold lookup.

    Args:
        svg_content: SVG source, forwarded to
            ``get_transformed_paths_and_coords``.
        mode: Input mode, forwarded unchanged.

    Returns:
        A dict with ``width``, ``height``, ``viewbox`` and the serialized
        ``layers``, ``fills`` and ``details`` lists.
    """
    paths_data, _coords, width, height, viewBox = get_transformed_paths_and_coords(svg_content, mode)

    layer_strokes = []
    fill_strokes = []
    detail_strokes = []

    # Thresholds are read from the sorted areas, so they only exist when at
    # least one path was extracted.
    if paths_data:
        for data in paths_data:
            xmin, xmax, ymin, ymax = data['path'].bbox()
            data['area'] = (xmax - xmin) * (ymax - ymin)

        areas_sorted = sorted(p['area'] for p in paths_data)
        area_threshold = areas_sorted[len(areas_sorted) // 2]
        layer_threshold = areas_sorted[int(len(areas_sorted) * 0.99)]

        for data in paths_data:
            if data['area'] >= layer_threshold:
                layer_strokes.append(data)
            elif data['area'] >= area_threshold:
                fill_strokes.append(data)
            else:
                detail_strokes.append(data)

    ordered_fills = sequence_strokes(fill_strokes, np.array([p['coord'] for p in fill_strokes]))
    ordered_details = sequence_strokes(detail_strokes, np.array([p['coord'] for p in detail_strokes]))

    return {
        "width": width,
        "height": height,
        "viewbox": viewBox,
        "layers": serialize_paths(layer_strokes),
        "fills": serialize_paths(ordered_fills),
        "details": serialize_paths(ordered_details)
    }