first commit
This commit is contained in:
263
lib/sketch.py
Normal file
263
lib/sketch.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('TkAgg')
|
||||
import matplotlib.pyplot as plt
|
||||
from .curves import *
|
||||
from .macro import *
|
||||
|
||||
|
||||
########################## base ###########################
|
||||
class SketchBase(object):
|
||||
"""Base class for sketch (a collection of curves). """
|
||||
def __init__(self, children, reorder=True):
|
||||
self.children = children
|
||||
|
||||
if reorder:
|
||||
self.reorder()
|
||||
|
||||
@staticmethod
|
||||
def from_dict(stat):
|
||||
"""construct sketch from json data
|
||||
|
||||
Args:
|
||||
stat (dict): dict from json data
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def from_vector(vec, start_point, is_numerical=True):
|
||||
"""construct sketch from vector representation
|
||||
|
||||
Args:
|
||||
vec (np.array): (seq_len, n_args)
|
||||
start_point (np.array): (2, ). If none, implicitly defined as the last end point.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def reorder(self):
|
||||
"""rearrange the curves to follow counter-clockwise direction"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def start_point(self):
|
||||
return self.children[0].start_point
|
||||
|
||||
@property
|
||||
def end_point(self):
|
||||
return self.children[-1].end_point
|
||||
|
||||
@property
|
||||
def bbox(self):
|
||||
"""compute bounding box (min/max points) of the sketch"""
|
||||
all_points = np.concatenate([child.bbox for child in self.children], axis=0)
|
||||
return np.stack([np.min(all_points, axis=0), np.max(all_points, axis=0)], axis=0)
|
||||
|
||||
@property
|
||||
def bbox_size(self):
|
||||
"""compute bounding box size (max of height and width)"""
|
||||
bbox_min, bbox_max = self.bbox[0], self.bbox[1]
|
||||
bbox_size = np.max(np.abs(np.concatenate([bbox_max - self.start_point, bbox_min - self.start_point])))
|
||||
return bbox_size
|
||||
|
||||
@property
|
||||
def global_trans(self):
|
||||
"""start point + sketch size (bbox_size)"""
|
||||
return np.concatenate([self.start_point, np.array([self.bbox_size])])
|
||||
|
||||
def transform(self, translate, scale):
|
||||
"""linear transformation"""
|
||||
for child in self.children:
|
||||
child.transform(translate, scale)
|
||||
|
||||
def flip(self, axis):
|
||||
for child in self.children:
|
||||
child.flip(axis)
|
||||
self.reorder()
|
||||
|
||||
def numericalize(self, n=256):
|
||||
"""quantize curve parameters into integers"""
|
||||
for child in self.children:
|
||||
child.numericalize(n)
|
||||
|
||||
def normalize(self, size=256):
|
||||
"""normalize within the given size, with start_point in the middle center"""
|
||||
cur_size = self.bbox_size
|
||||
scale = (size / 2 * NORM_FACTOR - 1) / cur_size # prevent potential overflow if data augmentation applied
|
||||
self.transform(-self.start_point, scale)
|
||||
self.transform(np.array((size / 2, size / 2)), 1)
|
||||
|
||||
def denormalize(self, bbox_size, size=256):
|
||||
"""inverse procedure of normalize method"""
|
||||
scale = bbox_size / (size / 2 * NORM_FACTOR - 1)
|
||||
self.transform(-np.array((size / 2, size / 2)), scale)
|
||||
|
||||
def to_vector(self):
|
||||
"""convert to vector representation"""
|
||||
raise NotImplementedError
|
||||
|
||||
def draw(self, ax):
|
||||
"""draw sketch on matplotlib ax"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_image(self):
|
||||
"""convert to image"""
|
||||
fig, ax = plt.subplots()
|
||||
self.draw(ax)
|
||||
ax.axis('equal')
|
||||
fig.canvas.draw()
|
||||
X = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3]
|
||||
plt.close(fig)
|
||||
return X
|
||||
|
||||
def sample_points(self, n=32):
|
||||
"""uniformly sample points from the sketch"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
####################### loop & profile #######################
|
||||
class Loop(SketchBase):
|
||||
"""Sketch loop, a sequence of connected curves."""
|
||||
@staticmethod
|
||||
def from_dict(stat):
|
||||
all_curves = [construct_curve_from_dict(item) for item in stat['profile_curves']]
|
||||
this_loop = Loop(all_curves)
|
||||
this_loop.is_outer = stat['is_outer']
|
||||
return this_loop
|
||||
|
||||
def __str__(self):
|
||||
return "Loop:" + "\n -" + "\n -".join([str(curve) for curve in self.children])
|
||||
|
||||
@staticmethod
|
||||
def from_vector(vec, start_point=None, is_numerical=True):
|
||||
all_curves = []
|
||||
if start_point is None:
|
||||
# FIXME: explicit for loop can be avoided here
|
||||
for i in range(vec.shape[0]):
|
||||
if vec[i][0] == EOS_IDX:
|
||||
start_point = vec[i - 1][1:3]
|
||||
break
|
||||
for i in range(vec.shape[0]):
|
||||
type = vec[i][0]
|
||||
if type == SOL_IDX:
|
||||
continue
|
||||
elif type == EOS_IDX:
|
||||
break
|
||||
else:
|
||||
curve = construct_curve_from_vector(vec[i], start_point, is_numerical=is_numerical)
|
||||
start_point = vec[i][1:3] # current curve's end_point serves as next curve's start_point
|
||||
all_curves.append(curve)
|
||||
return Loop(all_curves)
|
||||
|
||||
def reorder(self):
|
||||
"""reorder by starting left most and counter-clockwise"""
|
||||
if len(self.children) <= 1:
|
||||
return
|
||||
|
||||
start_curve_idx = -1
|
||||
sx, sy = 10000, 10000
|
||||
|
||||
# correct start-end point order
|
||||
if np.allclose(self.children[0].start_point, self.children[1].start_point) or \
|
||||
np.allclose(self.children[0].start_point, self.children[1].end_point):
|
||||
self.children[0].reverse()
|
||||
|
||||
# correct start-end point order and find left-most point
|
||||
for i, curve in enumerate(self.children):
|
||||
if i < len(self.children) - 1 and np.allclose(curve.end_point, self.children[i + 1].end_point):
|
||||
self.children[i + 1].reverse()
|
||||
if round(curve.start_point[0], 6) < round(sx, 6) or \
|
||||
(round(curve.start_point[0], 6) == round(sx, 6) and round(curve.start_point[1], 6) < round(sy, 6)):
|
||||
start_curve_idx = i
|
||||
sx, sy = curve.start_point
|
||||
|
||||
self.children = self.children[start_curve_idx:] + self.children[:start_curve_idx]
|
||||
|
||||
# ensure mostly counter-clock wise
|
||||
if isinstance(self.children[0], Circle) or isinstance(self.children[-1], Circle): # FIXME: hard-coded
|
||||
return
|
||||
start_vec = self.children[0].direction()
|
||||
end_vec = self.children[-1].direction(from_start=False)
|
||||
if np.cross(end_vec, start_vec) <= 0:
|
||||
for curve in self.children:
|
||||
curve.reverse()
|
||||
self.children.reverse()
|
||||
|
||||
def to_vector(self, max_len=None, add_sol=True, add_eos=True):
|
||||
loop_vec = np.stack([curve.to_vector() for curve in self.children], axis=0)
|
||||
if add_sol:
|
||||
loop_vec = np.concatenate([SOL_VEC[np.newaxis], loop_vec], axis=0)
|
||||
if add_eos:
|
||||
loop_vec = np.concatenate([loop_vec, EOS_VEC[np.newaxis]], axis=0)
|
||||
if max_len is None:
|
||||
return loop_vec
|
||||
|
||||
if loop_vec.shape[0] > max_len:
|
||||
return None
|
||||
elif loop_vec.shape[0] < max_len:
|
||||
pad_vec = np.tile(EOS_VEC, max_len - loop_vec.shape[0]).reshape((-1, len(EOS_VEC)))
|
||||
loop_vec = np.concatenate([loop_vec, pad_vec], axis=0) # (max_len, 1 + N_ARGS)
|
||||
return loop_vec
|
||||
|
||||
def draw(self, ax):
|
||||
colors = ['red', 'blue', 'green', 'brown', 'pink', 'yellow', 'purple', 'black'] * 10
|
||||
for i, curve in enumerate(self.children):
|
||||
curve.draw(ax, colors[i])
|
||||
|
||||
def sample_points(self, n=32):
|
||||
points = np.stack([curve.sample_points(n) for curve in self.children], axis=0) # (n_curves, n, 2)
|
||||
return points
|
||||
|
||||
|
||||
class Profile(SketchBase):
|
||||
"""Sketch profile,a closed region formed by one or more loops.
|
||||
The outer-most loop is placed at first."""
|
||||
@staticmethod
|
||||
def from_dict(stat):
|
||||
all_loops = [Loop.from_dict(item) for item in stat['loops']]
|
||||
return Profile(all_loops)
|
||||
|
||||
def __str__(self):
|
||||
return "Profile:" + "\n -".join([str(loop) for loop in self.children])
|
||||
|
||||
@staticmethod
|
||||
def from_vector(vec, start_point=None, is_numerical=True):
|
||||
all_loops = []
|
||||
command = vec[:, 0]
|
||||
end_idx = command.tolist().index(EOS_IDX)
|
||||
indices = np.where(command[:end_idx] == SOL_IDX)[0].tolist() + [end_idx]
|
||||
for i in range(len(indices) - 1):
|
||||
loop_vec = vec[indices[i]:indices[i + 1]]
|
||||
loop_vec = np.concatenate([loop_vec, EOS_VEC[np.newaxis]], axis=0)
|
||||
if loop_vec[0][0] == SOL_IDX and loop_vec[1][0] not in [SOL_IDX, EOS_IDX]:
|
||||
all_loops.append(Loop.from_vector(loop_vec, is_numerical=is_numerical))
|
||||
return Profile(all_loops)
|
||||
|
||||
def reorder(self):
|
||||
if len(self.children) <= 1:
|
||||
return
|
||||
all_loops_bbox_min = np.stack([loop.bbox[0] for loop in self.children], axis=0).round(6)
|
||||
ind = np.lexsort(all_loops_bbox_min.transpose()[[1, 0]])
|
||||
self.children = [self.children[i] for i in ind]
|
||||
|
||||
def draw(self, ax):
|
||||
for i, loop in enumerate(self.children):
|
||||
loop.draw(ax)
|
||||
ax.text(loop.start_point[0], loop.start_point[1], str(i))
|
||||
|
||||
def to_vector(self, max_n_loops=None, max_len_loop=None, pad=True):
|
||||
loop_vecs = [loop.to_vector(None, add_eos=False) for loop in self.children]
|
||||
if max_n_loops is not None and len(loop_vecs) > max_n_loops:
|
||||
return None
|
||||
for vec in loop_vecs:
|
||||
if max_len_loop is not None and vec.shape[0] > max_len_loop:
|
||||
return None
|
||||
profile_vec = np.concatenate(loop_vecs, axis=0)
|
||||
profile_vec = np.concatenate([profile_vec, EOS_VEC[np.newaxis]], axis=0)
|
||||
if pad:
|
||||
pad_len = max_n_loops * max_len_loop - profile_vec.shape[0]
|
||||
profile_vec = np.concatenate([profile_vec, EOS_VEC[np.newaxis].repeat(pad_len, axis=0)], axis=0)
|
||||
return profile_vec
|
||||
|
||||
def sample_points(self, n=32):
|
||||
points = np.concatenate([loop.sample_points(n) for loop in self.children], axis=0)
|
||||
return points
|
||||
Reference in New Issue
Block a user