Files
OpenECADv2toSTEP/lib/sketch.py
Yuki-Kokomi 7b6e0703c2 first commit
2024-08-07 10:07:38 +08:00

264 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from .curves import *
from .macro import *
########################## base ###########################
class SketchBase(object):
    """Base class for sketch (a collection of curves)."""

    def __init__(self, children, reorder=True):
        """
        Args:
            children (list): child elements (curves for a Loop, loops for a Profile).
            reorder (bool): if True, call self.reorder() right after storing children.
        """
        self.children = children
        if reorder:
            self.reorder()

    @staticmethod
    def from_dict(stat):
        """construct sketch from json data

        Args:
            stat (dict): dict from json data
        """
        raise NotImplementedError

    @staticmethod
    def from_vector(vec, start_point, is_numerical=True):
        """construct sketch from vector representation

        Args:
            vec (np.array): (seq_len, n_args)
            start_point (np.array): (2, ). If none, implicitly defined as the last end point.
        """
        raise NotImplementedError

    def reorder(self):
        """rearrange the curves to follow counter-clockwise direction"""
        raise NotImplementedError

    @property
    def start_point(self):
        """Start point of the first child."""
        return self.children[0].start_point

    @property
    def end_point(self):
        """End point of the last child."""
        return self.children[-1].end_point

    @property
    def bbox(self):
        """Bounding box of the sketch: row 0 is the per-axis min, row 1 the per-axis max.

        Built by concatenating every child's bbox along axis 0.
        """
        all_points = np.concatenate([child.bbox for child in self.children], axis=0)
        return np.stack([np.min(all_points, axis=0), np.max(all_points, axis=0)], axis=0)

    @property
    def bbox_size(self):
        """Largest absolute offset, along any axis, from start_point to the bbox min/max corners."""
        bbox_min, bbox_max = self.bbox  # compute the bbox once, not twice
        offsets = np.concatenate([bbox_max - self.start_point, bbox_min - self.start_point])
        return np.max(np.abs(offsets))

    @property
    def global_trans(self):
        """start point + sketch size (bbox_size), concatenated into one array"""
        return np.concatenate([self.start_point, np.array([self.bbox_size])])

    def transform(self, translate, scale):
        """linear transformation, delegated to every child"""
        for child in self.children:
            child.transform(translate, scale)

    def flip(self, axis):
        """Flip every child along `axis`, then reorder the sketch."""
        for child in self.children:
            child.flip(axis)
        self.reorder()

    def numericalize(self, n=256):
        """quantize curve parameters into integers (delegated to every child)"""
        for child in self.children:
            child.numericalize(n)

    def normalize(self, size=256):
        """No-op in this version; kept for API compatibility.

        The original implementation translated start_point to the canvas centre
        and scaled by (size / 2 * NORM_FACTOR - 1) / bbox_size. Those transform
        calls are disabled here, so the sketch is left unchanged.
        """
        return None

    def denormalize(self, bbox_size, size=256):
        """No-op in this version; kept for API compatibility (inverse of normalize).

        The original implementation scaled by bbox_size / (size / 2 * NORM_FACTOR - 1).
        Those transform calls are disabled here, so the sketch is left unchanged.
        """
        return None

    def to_vector(self):
        """convert to vector representation"""
        raise NotImplementedError

    def draw(self, ax):
        """draw sketch on matplotlib ax"""
        raise NotImplementedError

    def to_image(self):
        """Render the sketch with matplotlib and return the RGB pixel array."""
        fig, ax = plt.subplots()
        try:
            self.draw(ax)
            ax.axis('equal')
            fig.canvas.draw()
            image = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3]  # drop alpha
        finally:
            # always release the figure, even if drawing fails
            plt.close(fig)
        return image

    def sample_points(self, n=32):
        """uniformly sample points from the sketch"""
        raise NotImplementedError
####################### loop & profile #######################
class Loop(SketchBase):
    """Sketch loop, a sequence of connected curves."""

    @staticmethod
    def from_dict(stat):
        """Construct a loop from json data.

        Args:
            stat (dict): must provide 'profile_curves' (list of curve dicts, each
                passed to construct_curve_from_dict) and 'is_outer'.
        """
        all_curves = [construct_curve_from_dict(item) for item in stat['profile_curves']]
        this_loop = Loop(all_curves)
        this_loop.is_outer = stat['is_outer']
        return this_loop

    def __str__(self):
        return "Loop:" + "\n -" + "\n -".join([str(curve) for curve in self.children])

    @staticmethod
    def from_vector(vec, start_point=None, is_numerical=True):
        """Construct a loop from its vector representation.

        Args:
            vec (np.array): 2-D array; column 0 is the command index, columns 1:3
                are used as the curve's end point. SOL rows are skipped and decoding
                stops at the first EOS row.
            start_point (np.array): start point of the first curve. If None, the
                end point of the row just before the first EOS row is used.
            is_numerical (bool): forwarded to construct_curve_from_vector.
        """
        if start_point is None:
            eos_rows = np.flatnonzero(vec[:, 0] == EOS_IDX)
            if eos_rows.size > 0:
                # end point of the last curve closes the loop, so it is where it starts
                start_point = vec[eos_rows[0] - 1][1:3]
        all_curves = []
        for row in vec:
            command = row[0]
            if command == SOL_IDX:
                continue
            if command == EOS_IDX:
                break
            all_curves.append(construct_curve_from_vector(row, start_point, is_numerical=is_numerical))
            # current curve's end_point serves as next curve's start_point
            start_point = row[1:3]
        return Loop(all_curves)

    def reorder(self):
        """Reorder curves to start at the left-most point and run counter-clockwise.

        Curves are reversed in place where consecutive end points coincide, the
        list is rotated so the curve whose start point has the smallest x (ties
        broken by smallest y, compared after rounding to 6 decimals) comes first,
        and, unless a Circle sits at either end, the whole loop is reversed when
        the turn at its first vertex is not counter-clockwise.
        """
        if len(self.children) <= 1:
            return
        # correct start-end point order of the first curve
        first, second = self.children[0], self.children[1]
        if np.allclose(first.start_point, second.start_point) or \
                np.allclose(first.start_point, second.end_point):
            first.reverse()
        # correct start-end point order and find left-most point;
        # inf sentinel (not a fixed 10000) so any finite coordinate is accepted
        start_curve_idx = 0
        sx, sy = np.inf, np.inf
        for i, curve in enumerate(self.children):
            if i < len(self.children) - 1 and np.allclose(curve.end_point, self.children[i + 1].end_point):
                self.children[i + 1].reverse()
            x, y = round(curve.start_point[0], 6), round(curve.start_point[1], 6)
            if x < round(sx, 6) or (x == round(sx, 6) and y < round(sy, 6)):
                start_curve_idx = i
                sx, sy = curve.start_point
        self.children = self.children[start_curve_idx:] + self.children[:start_curve_idx]
        # ensure mostly counter-clock wise
        if isinstance(self.children[0], Circle) or isinstance(self.children[-1], Circle):  # FIXME: hard-coded
            return
        start_vec = self.children[0].direction()
        end_vec = self.children[-1].direction(from_start=False)
        # z-component of the 2-D cross product end_vec x start_vec, written out
        # because np.cross on 2-element vectors is deprecated since NumPy 2.0
        turn = end_vec[0] * start_vec[1] - end_vec[1] * start_vec[0]
        if turn <= 0:
            for curve in self.children:
                curve.reverse()
            self.children.reverse()

    def to_vector(self, max_len=None, add_sol=True, add_eos=True):
        """Convert the loop to its vector representation.

        Args:
            max_len (int): if given, pad with EOS rows up to max_len rows, or
                return None when the loop needs more rows than that.
            add_sol (bool): prepend one SOL row.
            add_eos (bool): append one EOS row.
        Returns:
            np.array of stacked curve vectors, or None if it exceeds max_len.
        """
        loop_vec = np.stack([curve.to_vector() for curve in self.children], axis=0)
        if add_sol:
            loop_vec = np.concatenate([SOL_VEC[np.newaxis], loop_vec], axis=0)
        if add_eos:
            loop_vec = np.concatenate([loop_vec, EOS_VEC[np.newaxis]], axis=0)
        if max_len is None:
            return loop_vec
        if loop_vec.shape[0] > max_len:
            return None
        if loop_vec.shape[0] < max_len:
            pad_vec = np.tile(EOS_VEC, max_len - loop_vec.shape[0]).reshape((-1, len(EOS_VEC)))
            loop_vec = np.concatenate([loop_vec, pad_vec], axis=0)  # (max_len, 1 + N_ARGS)
        return loop_vec

    def draw(self, ax):
        """Draw every curve on ax, cycling through a fixed colour palette."""
        colors = ['red', 'blue', 'green', 'brown', 'pink', 'yellow', 'purple', 'black']
        for i, curve in enumerate(self.children):
            # modulo instead of a fixed-length list: loops of any length are drawable
            curve.draw(ax, colors[i % len(colors)])

    def sample_points(self, n=32):
        """Sample n points from every curve and stack them per curve."""
        points = np.stack([curve.sample_points(n) for curve in self.children], axis=0)  # (n_curves, n, 2)
        return points
class Profile(SketchBase):
    """Sketch profile: a closed region formed by one or more loops.

    The outer-most loop is placed at first.
    """

    @staticmethod
    def from_dict(stat):
        """Construct a profile from json data whose 'loops' entry lists loop dicts."""
        all_loops = [Loop.from_dict(item) for item in stat['loops']]
        return Profile(all_loops)

    def __str__(self):
        # prefix every loop, including the first, with the separator Loop.__str__ uses
        return "Profile:" + "\n -" + "\n -".join([str(loop) for loop in self.children])

    @staticmethod
    def from_vector(vec, start_point=None, is_numerical=True):
        """Construct a profile from its vector representation.

        The rows before the first EOS row (or the whole vector if there is none)
        are split at every SOL row; each segment that contains at least one curve
        row is decoded with Loop.from_vector.

        Args:
            vec (np.array): 2-D array whose column 0 is the command index.
            start_point: ignored; each loop derives its own start point.
            is_numerical (bool): forwarded to Loop.from_vector.
        """
        command = vec[:, 0]
        eos_rows = np.flatnonzero(command == EOS_IDX)
        # a vector without EOS is treated as terminated at its end
        end_idx = int(eos_rows[0]) if eos_rows.size > 0 else len(command)
        bounds = np.flatnonzero(command[:end_idx] == SOL_IDX).tolist() + [end_idx]
        all_loops = []
        for begin, end in zip(bounds[:-1], bounds[1:]):
            loop_vec = np.concatenate([vec[begin:end], EOS_VEC[np.newaxis]], axis=0)
            # skip empty segments (SOL immediately followed by SOL/EOS)
            if loop_vec[0][0] == SOL_IDX and loop_vec[1][0] not in [SOL_IDX, EOS_IDX]:
                all_loops.append(Loop.from_vector(loop_vec, is_numerical=is_numerical))
        return Profile(all_loops)

    def reorder(self):
        """Sort loops by bbox min x, then bbox min y (both rounded to 6 decimals)."""
        if len(self.children) <= 1:
            return
        bbox_mins = np.stack([loop.bbox[0] for loop in self.children], axis=0).round(6)
        # np.lexsort treats the LAST key as primary: x first, y breaks ties
        order = np.lexsort((bbox_mins[:, 1], bbox_mins[:, 0]))
        self.children = [self.children[i] for i in order]

    def draw(self, ax):
        """Draw every loop and label it with its index at its start point."""
        for i, loop in enumerate(self.children):
            loop.draw(ax)
            ax.text(loop.start_point[0], loop.start_point[1], str(i))

    def to_vector(self, max_n_loops=None, max_len_loop=None, pad=True):
        """Convert the profile to its vector representation.

        Each loop is encoded with a leading SOL row and no EOS row; the loops are
        concatenated and one EOS row is appended.

        Args:
            max_n_loops (int): if given, return None when there are more loops.
            max_len_loop (int): if given, return None when any loop vector is longer.
            pad (bool): pad with EOS rows up to max_n_loops * max_len_loop rows.
                Padding is skipped when either limit is None.
        Returns:
            np.array, or None if a limit is exceeded.
        """
        loop_vecs = [loop.to_vector(None, add_eos=False) for loop in self.children]
        if max_n_loops is not None and len(loop_vecs) > max_n_loops:
            return None
        if max_len_loop is not None and any(vec.shape[0] > max_len_loop for vec in loop_vecs):
            return None
        profile_vec = np.concatenate(loop_vecs + [EOS_VEC[np.newaxis]], axis=0)
        if pad and max_n_loops is not None and max_len_loop is not None:
            # NOTE(review): pad_len is -1 (and repeat raises) when the loops exactly
            # fill max_n_loops * max_len_loop, because of the trailing EOS row.
            pad_len = max_n_loops * max_len_loop - profile_vec.shape[0]
            profile_vec = np.concatenate([profile_vec, EOS_VEC[np.newaxis].repeat(pad_len, axis=0)], axis=0)
        return profile_vec

    def sample_points(self, n=32):
        """Concatenate the sampled points of every loop along axis 0."""
        points = np.concatenate([loop.sample_points(n) for loop in self.children], axis=0)
        return points