Linea / utils.py
potato
add utils.py
1670782
import os
import re
import numpy as np
import vtracer
import svgpathtools
import cairosvg
import io
import cv2
from lxml import etree
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import cdist
from python_tsp.heuristics import solve_tsp_local_search
from fast_tsp import find_tour
from svgpathtools import Path
from tqdm import tqdm
def parse_transform(transform_str):
    """
    Convert an SVG ``transform`` attribute value into a 3x3 homogeneous matrix.

    The transform functions are processed in the order they appear in the
    string and each one is post-multiplied onto the running matrix, so
    ``"translate(10,20) scale(2)"`` yields ``T @ S`` (scale is applied to the
    point first, then the translation), as the SVG transform-list rules require.

    Supported functions: ``matrix``, ``translate``, ``scale``, ``rotate``
    (with optional centre), ``skewX`` and ``skewY``. Arguments may be
    separated by commas and/or whitespace. Unknown functions, and functions
    with an unusable number of arguments, are skipped.

    Args:
        transform_str: The attribute value, or ``None`` / ``""``.

    Returns:
        A 3x3 ``numpy`` array; the identity when nothing could be parsed.
    """
    matrix = np.eye(3)
    if not transform_str:
        return matrix

    number = r"[-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?"
    # Each match is (function name, raw argument text between the parentheses).
    for name, raw_args in re.findall(r"([A-Za-z]+)\s*\(([^)]*)\)", transform_str):
        args = [float(v) for v in re.findall(number, raw_args)]
        if not args:
            continue

        if name == "matrix":
            if len(args) != 6:
                continue
            a, b, c, d, e, f = args
            m = np.array([[a, c, e], [b, d, f], [0, 0, 1]])
        elif name == "translate":
            tx = args[0]
            ty = args[1] if len(args) == 2 else 0
            m = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        elif name == "scale":
            sx = args[0]
            sy = args[1] if len(args) == 2 else sx
            m = np.array([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        elif name == "rotate":
            angle = np.deg2rad(args[0])
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            m = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
            if len(args) == 3:
                # rotate(a, cx, cy) == translate(cx, cy) rotate(a) translate(-cx, -cy)
                cx, cy = args[1], args[2]
                to_centre = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
                from_centre = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
                m = to_centre @ m @ from_centre
        elif name == "skewX":
            m = np.array([[1, np.tan(np.deg2rad(args[0])), 0], [0, 1, 0], [0, 0, 1]])
        elif name == "skewY":
            m = np.array([[1, 0, 0], [np.tan(np.deg2rad(args[0])), 1, 0], [0, 0, 1]])
        else:
            continue

        # Post-multiply so the left-most function ends up outermost.
        matrix = matrix @ m
    return matrix
def get_global_transform(element):
    """
    Compose the ``transform`` attributes of an element and all its ancestors.

    The chain is walked from ``element`` up through ``getparent()`` until it
    returns ``None``. Each ancestor's matrix is multiplied on the left of the
    accumulated result, so outer (ancestor) transforms are applied last.

    Args:
        element: A node exposing ``get(name)`` and ``getparent()``, or ``None``.

    Returns:
        A 3x3 ``numpy`` array; the identity if no node carries a transform.
    """
    # Collect the attribute values innermost-first.
    attr_chain = []
    node = element
    while node is not None:
        attr_chain.append(node.get("transform"))
        node = node.getparent()

    combined = np.eye(3)
    for attr in attr_chain:
        if attr:
            combined = parse_transform(attr) @ combined
    return combined
def get_transformed_paths_and_coords(svg_content, mode):
    """
    Parse an SVG and collect its ``<path>`` elements with their global transforms.

    Args:
        svg_content: When ``mode == 'file'``, the SVG document as a ``str``
            (it is UTF-8 encoded and wrapped in a ``BytesIO`` before parsing).
            For any other mode it is handed to ``lxml.etree.parse`` unchanged.
        mode: Selects how ``svg_content`` is interpreted (see above).

    Returns:
        A tuple ``(paths_data, coords, width, height, viewBox)``:

        * ``paths_data`` - one dict per usable path with keys ``'path'``
          (parsed svgpathtools path), ``'transform'`` (3x3 global matrix),
          ``'element'`` (the lxml element) and ``'coord'`` (transformed start
          point as an ``(x, y)`` tuple).
        * ``coords`` - ``numpy`` array of the ``'coord'`` values, same order.
        * ``width`` / ``height`` - the root element's raw attribute values
          (may be ``None`` and may carry units).
        * ``viewBox`` - the root's ``viewBox`` attribute, or a synthesized
          ``"0 0 W H"`` when it is missing and width/height are present.

    Path elements without a ``d`` attribute, or whose ``d`` parses to a path
    with no segments, are skipped.
    """
    parser = etree.XMLParser(remove_blank_text=True)
    if mode == 'file':
        svg_content = io.BytesIO(svg_content.encode('utf-8'))
    root = etree.parse(svg_content, parser).getroot()

    width = root.get("width")
    height = root.get("height")
    viewBox = root.get("viewBox")

    if not viewBox and width and height:
        # Remove potential units like 'px' to get clean numbers
        bare_width = re.sub(r'[a-zA-Z%]', '', width)
        bare_height = re.sub(r'[a-zA-Z%]', '', height)
        viewBox = f"0 0 {bare_width} {bare_height}"
        print(f"SVG missing viewBox. Created a default: '{viewBox}'")

    transformed_coords = []
    paths_data = []
    for elem in root.findall(".//{*}path"):
        d_string = elem.get('d')
        if not d_string:
            continue

        path = svgpathtools.parse_path(d_string)
        if len(path) == 0:
            # A path with no segments (e.g. a lone moveto) has no start point
            # to sequence on and nothing to draw, so leave it out.
            continue

        transform = get_global_transform(elem)
        start_vec = np.array([[path.start.real], [path.start.imag], [1]])
        transformed_start = transform @ start_vec
        coord = (transformed_start[0, 0], transformed_start[1, 0])

        transformed_coords.append(coord)
        paths_data.append({
            'path': path,
            'transform': transform,
            'element': elem,
            'coord': coord,
        })

    print(f"Extracted {len(transformed_coords)} paths with their elements.")
    return paths_data, np.array(transformed_coords), width, height, viewBox
def transform_path(path, matrix):
    """
    Apply a 3x3 numpy transform to an svgpathtools Path object.

    Each segment is transformed as a whole by ``svgpathtools.path.transform``,
    so Bezier control points and arc parameters move together with the
    segment end points instead of only ``start``/``end`` being updated.
    The segments of ``path`` are not modified in place.

    Args:
        path: Iterable of svgpathtools segments (typically a ``Path``).
        matrix: 3x3 homogeneous transformation matrix.

    Returns:
        A new ``Path`` built from the transformed segments.
    """
    tf = np.asarray(matrix, dtype=float)
    return Path(*(svgpathtools.path.transform(seg, tf) for seg in path))
def sequence_strokes(paths_data, coords, proximity_threshold=40):
    """
    Group strokes into spatial clusters and order the clusters with a TSP tour.

    Strokes are clustered with Ward linkage on ``coords`` and cut at
    ``proximity_threshold`` (``fcluster`` with the ``'distance'`` criterion).
    The cluster centroids (mean coordinate per cluster) are then ordered by
    ``find_tour`` on their pairwise distance matrix, truncated to integers.
    Within a cluster the strokes keep their incoming relative order.

    Args:
        paths_data: Sequence of stroke records, parallel to ``coords``.
        coords: One 2D point per stroke.
        proximity_threshold: Linkage distance at which clusters are cut.

    Returns:
        ``paths_data`` itself when there are fewer than two strokes or only
        one cluster; otherwise a new list of the same records in tour order.
    """
    if len(coords) < 2:
        print("Fewer than 2 strokes, no sequencing needed.")
        return paths_data

    print("Clustering strokes by proximity...")
    labels = fcluster(
        linkage(coords, method='ward'),
        t=proximity_threshold,
        criterion='distance',
    )
    num_clusters = len(set(labels))
    print(f"{num_clusters} clusters detected.")

    if num_clusters <= 1:
        print("All strokes are in a single cluster, no reordering needed.")
        return paths_data

    # Cluster labels are 1-based; keep one (strokes, points) bucket per label.
    buckets = {
        cluster_id: ([], [])
        for cluster_id in range(1, num_clusters + 1)
    }
    for position, cluster_id in enumerate(labels):
        strokes, points = buckets[cluster_id]
        strokes.append(paths_data[position])
        points.append(coords[position])

    centroids = [np.mean(points, axis=0) for _, points in buckets.values()]

    print("Solving TSP for optimal cluster drawing order...")
    # The tour solver is given integer distances, hence the truncating cast.
    pairwise = cdist(centroids, centroids).astype(np.int32).tolist()
    tour = find_tour(pairwise)

    # Tour entries are 0-based centroid indices; shift back to 1-based labels.
    final_sequence = [
        stroke
        for cluster_idx in tour
        for stroke in buckets[cluster_idx + 1][0]
    ]

    print("Final stroke sequence created.")
    return final_sequence
def serialize_paths(paths_data):
    """
    Reduce path records to plain dicts holding only drawing attributes.

    For each record the ``'element'`` entry is read via ``.get(...)`` and
    turned into ``{"d", "transform", "fill"}``. The fill colour is taken from
    the ``fill`` attribute when present and non-empty, otherwise from a
    ``fill:`` declaration inside the ``style`` attribute, otherwise it
    defaults to ``"#000000"``.

    Args:
        paths_data: Iterable of dicts with an ``'element'`` key.

    Returns:
        A list of dicts, one per input record, in the same order.
    """
    fill_pattern = re.compile(r'fill:\s*([^;]+)')

    serialized = []
    for data in paths_data:
        elem = data['element']
        fill = elem.get("fill")
        if not fill:
            # Fall back to the inline style declaration, then to black.
            fill_match = fill_pattern.search(elem.get("style", ""))
            fill = fill_match.group(1).strip() if fill_match else "#000000"
        serialized.append({
            "d": elem.get("d"),
            "transform": elem.get("transform"),
            "fill": fill,
        })
    return serialized
def process_svg(svg_content, mode):
    """
    Split an SVG's paths into layers, fills and details and order them for drawing.

    Every path gets an ``'area'`` (its bounding-box area, in untransformed
    path coordinates). Two thresholds are taken from the sorted areas: the
    median, and the value at index ``int(n * 0.99)``. Paths are bucketed as:

    * ``layers``  - area >= the 99% threshold (kept in document order),
    * ``fills``   - area >= the median (ordered by ``sequence_strokes``),
    * ``details`` - everything smaller (ordered by ``sequence_strokes``).

    Args:
        svg_content: Passed through to ``get_transformed_paths_and_coords``.
        mode: Passed through to ``get_transformed_paths_and_coords``.

    Returns:
        A dict with keys ``"width"``, ``"height"``, ``"viewbox"``,
        ``"layers"``, ``"fills"`` and ``"details"``; the last three are lists
        produced by ``serialize_paths`` (empty when the SVG has no paths).
    """
    paths_data, _coords, width, height, viewBox = get_transformed_paths_and_coords(svg_content, mode)

    if not paths_data:
        # Nothing to rank: indexing the sorted areas below would fail.
        return {
            "width": width,
            "height": height,
            "viewbox": viewBox,
            "layers": [],
            "fills": [],
            "details": [],
        }

    for data in paths_data:
        xmin, xmax, ymin, ymax = data['path'].bbox()
        data['area'] = (xmax - xmin) * (ymax - ymin)

    areas_sorted = sorted(p['area'] for p in paths_data)
    area_threshold = areas_sorted[len(areas_sorted) // 2]
    layer_threshold = areas_sorted[int(len(areas_sorted) * 0.99)]

    fill_strokes = []
    detail_strokes = []
    layer_strokes = []
    for data in paths_data:
        if data['area'] >= layer_threshold:
            layer_strokes.append(data)
        elif data['area'] >= area_threshold:
            fill_strokes.append(data)
        else:
            detail_strokes.append(data)

    ordered_fills = sequence_strokes(fill_strokes, np.array([p['coord'] for p in fill_strokes]))
    ordered_details = sequence_strokes(detail_strokes, np.array([p['coord'] for p in detail_strokes]))

    return {
        "width": width,
        "height": height,
        "viewbox": viewBox,
        "layers": serialize_paths(layer_strokes),
        "fills": serialize_paths(ordered_fills),
        "details": serialize_paths(ordered_details),
    }