add Bethany tool

This commit is contained in:
Yuki-Kokomi
2024-08-16 17:50:51 +08:00
parent be3615ab12
commit c0326ca5eb
37 changed files with 19245 additions and 0 deletions

263
Bethany/lib/sketch.py Normal file
View File

@@ -0,0 +1,263 @@
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from .curves import *
from .macro import *
########################## base ###########################
class SketchBase(object):
"""Base class for sketch (a collection of curves). """
def __init__(self, children, reorder=True):
self.children = children
if reorder:
self.reorder()
@staticmethod
def from_dict(stat):
"""construct sketch from json data
Args:
stat (dict): dict from json data
"""
raise NotImplementedError
@staticmethod
def from_vector(vec, start_point, is_numerical=True):
"""construct sketch from vector representation
Args:
vec (np.array): (seq_len, n_args)
start_point (np.array): (2, ). If none, implicitly defined as the last end point.
"""
raise NotImplementedError
def reorder(self):
"""rearrange the curves to follow counter-clockwise direction"""
raise NotImplementedError
@property
def start_point(self):
return self.children[0].start_point
@property
def end_point(self):
return self.children[-1].end_point
@property
def bbox(self):
"""compute bounding box (min/max points) of the sketch"""
all_points = np.concatenate([child.bbox for child in self.children], axis=0)
return np.stack([np.min(all_points, axis=0), np.max(all_points, axis=0)], axis=0)
@property
def bbox_size(self):
"""compute bounding box size (max of height and width)"""
bbox_min, bbox_max = self.bbox[0], self.bbox[1]
bbox_size = np.max(np.abs(np.concatenate([bbox_max - self.start_point, bbox_min - self.start_point])))
return bbox_size
@property
def global_trans(self):
"""start point + sketch size (bbox_size)"""
return np.concatenate([self.start_point, np.array([self.bbox_size])])
def transform(self, translate, scale):
"""linear transformation"""
for child in self.children:
child.transform(translate, scale)
def flip(self, axis):
for child in self.children:
child.flip(axis)
self.reorder()
def numericalize(self, n=256):
"""quantize curve parameters into integers"""
for child in self.children:
child.numericalize(n)
def normalize(self, size=256):
"""normalize within the given size, with start_point in the middle center"""
cur_size = self.bbox_size
scale = (size / 2 * NORM_FACTOR - 1) / cur_size # prevent potential overflow if data augmentation applied
#self.transform(-self.start_point, scale)
#self.transform(np.array((size / 2, size / 2)), 1)
def denormalize(self, bbox_size, size=256):
"""inverse procedure of normalize method"""
scale = bbox_size / (size / 2 * NORM_FACTOR - 1)
#self.transform(-np.array((size / 2, size / 2)), scale)
def to_vector(self):
"""convert to vector representation"""
raise NotImplementedError
def draw(self, ax):
"""draw sketch on matplotlib ax"""
raise NotImplementedError
def to_image(self):
"""convert to image"""
fig, ax = plt.subplots()
self.draw(ax)
ax.axis('equal')
fig.canvas.draw()
X = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3]
plt.close(fig)
return X
def sample_points(self, n=32):
"""uniformly sample points from the sketch"""
raise NotImplementedError
####################### loop & profile #######################
class Loop(SketchBase):
"""Sketch loop, a sequence of connected curves."""
@staticmethod
def from_dict(stat):
all_curves = [construct_curve_from_dict(item) for item in stat['profile_curves']]
this_loop = Loop(all_curves)
this_loop.is_outer = stat['is_outer']
return this_loop
def __str__(self):
return "Loop:" + "\n -" + "\n -".join([str(curve) for curve in self.children])
@staticmethod
def from_vector(vec, start_point=None, is_numerical=True):
all_curves = []
if start_point is None:
# FIXME: explicit for loop can be avoided here
for i in range(vec.shape[0]):
if vec[i][0] == EOS_IDX:
start_point = vec[i - 1][1:3]
break
for i in range(vec.shape[0]):
type = vec[i][0]
if type == SOL_IDX:
continue
elif type == EOS_IDX:
break
else:
curve = construct_curve_from_vector(vec[i], start_point, is_numerical=is_numerical)
start_point = vec[i][1:3] # current curve's end_point serves as next curve's start_point
all_curves.append(curve)
return Loop(all_curves)
def reorder(self):
"""reorder by starting left most and counter-clockwise"""
if len(self.children) <= 1:
return
start_curve_idx = -1
sx, sy = 10000, 10000
# correct start-end point order
if np.allclose(self.children[0].start_point, self.children[1].start_point) or \
np.allclose(self.children[0].start_point, self.children[1].end_point):
self.children[0].reverse()
# correct start-end point order and find left-most point
for i, curve in enumerate(self.children):
if i < len(self.children) - 1 and np.allclose(curve.end_point, self.children[i + 1].end_point):
self.children[i + 1].reverse()
if round(curve.start_point[0], 6) < round(sx, 6) or \
(round(curve.start_point[0], 6) == round(sx, 6) and round(curve.start_point[1], 6) < round(sy, 6)):
start_curve_idx = i
sx, sy = curve.start_point
self.children = self.children[start_curve_idx:] + self.children[:start_curve_idx]
# ensure mostly counter-clock wise
if isinstance(self.children[0], Circle) or isinstance(self.children[-1], Circle): # FIXME: hard-coded
return
start_vec = self.children[0].direction()
end_vec = self.children[-1].direction(from_start=False)
if np.cross(end_vec, start_vec) <= 0:
for curve in self.children:
curve.reverse()
self.children.reverse()
def to_vector(self, max_len=None, add_sol=True, add_eos=True):
loop_vec = np.stack([curve.to_vector() for curve in self.children], axis=0)
if add_sol:
loop_vec = np.concatenate([SOL_VEC[np.newaxis], loop_vec], axis=0)
if add_eos:
loop_vec = np.concatenate([loop_vec, EOS_VEC[np.newaxis]], axis=0)
if max_len is None:
return loop_vec
if loop_vec.shape[0] > max_len:
return None
elif loop_vec.shape[0] < max_len:
pad_vec = np.tile(EOS_VEC, max_len - loop_vec.shape[0]).reshape((-1, len(EOS_VEC)))
loop_vec = np.concatenate([loop_vec, pad_vec], axis=0) # (max_len, 1 + N_ARGS)
return loop_vec
def draw(self, ax):
colors = ['red', 'blue', 'green', 'brown', 'pink', 'yellow', 'purple', 'black'] * 10
for i, curve in enumerate(self.children):
curve.draw(ax, colors[i])
def sample_points(self, n=32):
points = np.stack([curve.sample_points(n) for curve in self.children], axis=0) # (n_curves, n, 2)
return points
class Profile(SketchBase):
"""Sketch profilea closed region formed by one or more loops.
The outer-most loop is placed at first."""
@staticmethod
def from_dict(stat):
all_loops = [Loop.from_dict(item) for item in stat['loops']]
return Profile(all_loops)
def __str__(self):
return "Profile:" + "\n -".join([str(loop) for loop in self.children])
@staticmethod
def from_vector(vec, start_point=None, is_numerical=True):
all_loops = []
command = vec[:, 0]
end_idx = command.tolist().index(EOS_IDX)
indices = np.where(command[:end_idx] == SOL_IDX)[0].tolist() + [end_idx]
for i in range(len(indices) - 1):
loop_vec = vec[indices[i]:indices[i + 1]]
loop_vec = np.concatenate([loop_vec, EOS_VEC[np.newaxis]], axis=0)
if loop_vec[0][0] == SOL_IDX and loop_vec[1][0] not in [SOL_IDX, EOS_IDX]:
all_loops.append(Loop.from_vector(loop_vec, is_numerical=is_numerical))
return Profile(all_loops)
def reorder(self):
if len(self.children) <= 1:
return
all_loops_bbox_min = np.stack([loop.bbox[0] for loop in self.children], axis=0).round(6)
ind = np.lexsort(all_loops_bbox_min.transpose()[[1, 0]])
self.children = [self.children[i] for i in ind]
def draw(self, ax):
for i, loop in enumerate(self.children):
loop.draw(ax)
ax.text(loop.start_point[0], loop.start_point[1], str(i))
def to_vector(self, max_n_loops=None, max_len_loop=None, pad=True):
loop_vecs = [loop.to_vector(None, add_eos=False) for loop in self.children]
if max_n_loops is not None and len(loop_vecs) > max_n_loops:
return None
for vec in loop_vecs:
if max_len_loop is not None and vec.shape[0] > max_len_loop:
return None
profile_vec = np.concatenate(loop_vecs, axis=0)
profile_vec = np.concatenate([profile_vec, EOS_VEC[np.newaxis]], axis=0)
if pad:
pad_len = max_n_loops * max_len_loop - profile_vec.shape[0]
profile_vec = np.concatenate([profile_vec, EOS_VEC[np.newaxis].repeat(pad_len, axis=0)], axis=0)
return profile_vec
def sample_points(self, n=32):
points = np.concatenate([loop.sample_points(n) for loop in self.children], axis=0)
return points