Skip to content

Commit

Permalink
move obj meshes out of cache, oishape now support both preload and ru…
Browse files Browse the repository at this point in the history
…ntime load obj
  • Loading branch information
lixiny committed Oct 19, 2023
1 parent efbffca commit ff42015
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 49 deletions.
2 changes: 1 addition & 1 deletion oikit/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.1.0"
__version__ = "1.2.0"
108 changes: 61 additions & 47 deletions oikit/oi_shape/oi_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,30 @@

class OakInkShape:

def __init__(self,
data_split=ALL_SPLIT,
intent_mode=list(ALL_INTENT),
category=ALL_CAT,
mano_assets_root="assets/mano_v1_2",
use_cache=True,
use_downsample_mesh=False):
def __init__(
self,
data_split=ALL_SPLIT,
intent_mode=list(ALL_INTENT),
category=ALL_CAT,
mano_assets_root="assets/mano_v1_2",
use_cache=True,
use_downsample_mesh=False,
preload_obj=False,
):
self.name = "OakInkShape"
self.use_downsample_mesh = use_downsample_mesh
self.use_cache = use_cache
self.preload_obj = preload_obj

assert 'OAKINK_DIR' in os.environ, "environment variable 'OAKINK_DIR' is not set"
data_dir = os.path.join(os.environ['OAKINK_DIR'], "shape")
oi_shape_dir = os.path.join(data_dir, "oakink_shape_v2")
meta_dir = os.path.join(data_dir, "metaV2")

self.data_dir = data_dir
self.meta_dir = meta_dir
self.oi_shape_dir = oi_shape_dir

if data_split == 'all':
data_split = ALL_SPLIT
if category == 'all':
Expand Down Expand Up @@ -72,19 +82,36 @@ def __init__(self,
with open(cache_path, "rb") as p_f:
cache = pickle.load(p_f)
self.grasp_list = cache["grasp_list"]
self.obj_warehouse = cache["obj_warehouse"]
self.data_dir = data_dir
self.meta_dir = meta_dir
return
else:
self.grasp_list = self._prepare_data()
cache = {"grasp_list": self.grasp_list}
with open(cache_path, "wb") as f:
pickle.dump(cache, f)
print(f"{self.name} cache saved to {cache_path}")
else:
self.grasp_list = self._prepare_data()

# * >>>> create obj warehouse
self.obj_warehouse = {}
self.obj_id_set = {g["obj_id"] for g in self.grasp_list}
if preload_obj is True:
suppress_trimesh_logging()
for oid in tqdm(self.obj_id_set, desc="oikit preLoad obj model"):
obj_path = get_obj_path(oid, data_dir, meta_dir, use_downsample=use_downsample_mesh)
obj_trimesh = trimesh.load(obj_path, process=False, force="mesh", skip_materials=True)
bbox_center = (obj_trimesh.vertices.min(0) + obj_trimesh.vertices.max(0)) / 2
obj_trimesh.vertices = obj_trimesh.vertices - bbox_center
self.obj_warehouse[oid] = obj_trimesh

# * >>>> filter with regex
def _prepare_data(self):
# region ===== filter with regex >>>>>
grasp_list = []
category_begin_idx = []
seq_cat_matcher = re.compile(r"(.+)/(.{6})_(.{4})_([_0-9]+)/([\-0-9]+)")
for cat in tqdm(self.categories, desc="Process categories"):
real_matcher = re.compile(rf"({cat}/(.{{6}})/.{{10}})/hand_param\.pkl$")
virtual_matcher = re.compile(rf"({cat}/(.{{6}})/.{{10}})/(.{{6}})/hand_param\.pkl$")
path = os.path.join(oi_shape_dir, cat)
path = os.path.join(self.oi_shape_dir, cat)
category_begin_idx.append(len(grasp_list))
for cur, dirs, files in os.walk(path, followlinks=False):
dirs.sort()
Expand All @@ -95,15 +122,11 @@ def __init__(self,
if len(re_match) > 0:
# ? regex should return : [(path, raw_oid, tag, [oid])]
assert len(re_match) == 1, "regex should return only one match"
source = open(os.path.join(oi_shape_dir, re_match[0][0], "source.txt")).read()
source = open(os.path.join(self.oi_shape_dir, re_match[0][0], "source.txt")).read()
grasp_cat_match = seq_cat_matcher.findall(source)[0]
pass_stage, raw_obj_id, action_id, subject_id, seq_ts = (
grasp_cat_match[0],
grasp_cat_match[1],
grasp_cat_match[2],
grasp_cat_match[3],
grasp_cat_match[4],
)
pass_stage, raw_obj_id, action_id, subject_id, seq_ts = (grasp_cat_match[0], grasp_cat_match[1],
grasp_cat_match[2], grasp_cat_match[3],
grasp_cat_match[4])
obj_id = re_match[0][2] if is_virtual else re_match[0][1]
assert (is_virtual and raw_obj_id == re_match[0][1]) or obj_id == raw_obj_id
# * filter with intent mode
Expand Down Expand Up @@ -147,9 +170,9 @@ def __init__(self,
"alt_grasp_item": None,
}
grasp_list.append(grasp_item)
# * <<<<
# endregion <<<<

# * >>>> cal hand joints
# region ===== cal hand joints >>>>>
batch_hand_pose = []
batch_hand_shape = []
batch_hand_tsl = []
Expand All @@ -168,9 +191,9 @@ def __init__(self,
grasp_list[i]["joints"] = batch_hand_joints[i]
grasp_list[i]["verts"] = batch_hand_verts[i]
grasp_list[i]["hand_tsl"] = batch_hand_tsl[i]
# * <<<<
# endregion <<<<<

# * >>>> handle handover
# region ===== deal with handle handover >>>>>
if "handover" in self.intent_mode:
for i, g in tqdm(enumerate(grasp_list), total=len(grasp_list), desc="Process handover grasp"):
if g["subject_alt_id"] is None:
Expand All @@ -193,35 +216,26 @@ def __init__(self,
}
break
grasp_list = list(filter(lambda x: x["action_id"] != "0004" or x["alt_grasp_item"] is not None, grasp_list))
# * <<<<

# * >>>> create obj warehouse
suppress_trimesh_logging()
self.obj_warehouse = {}
obj_id_set = {g["obj_id"] for g in grasp_list}
for oid in tqdm(obj_id_set, desc="Load obj model"):
obj_path = get_obj_path(oid, data_dir, meta_dir, use_downsample=use_downsample_mesh)
obj_trimesh = trimesh.load(obj_path, process=False, force="mesh", skip_materials=True)
bbox_center = (obj_trimesh.vertices.min(0) + obj_trimesh.vertices.max(0)) / 2
obj_trimesh.vertices = obj_trimesh.vertices - bbox_center
self.obj_warehouse[oid] = obj_trimesh

self.grasp_list = grasp_list
self.data_dir = data_dir
self.meta_dir = meta_dir
# endregion <<<<<

if use_cache is True:
cache = {"grasp_list": self.grasp_list, "obj_warehouse": self.obj_warehouse}
with open(cache_path, "wb") as f:
pickle.dump(cache, f)
print(f"{self.name} cache saved to {cache_path}")
return grasp_list

def __len__(self):
return len(self.grasp_list)

def get_obj_mesh(self, idx):
obj_id = self.grasp_list[idx]["obj_id"]
if obj_id not in self.obj_warehouse:
obj_path = get_obj_path(obj_id, self.data_dir, self.meta_dir, use_downsample=self.use_downsample_mesh)
obj_trimesh = trimesh.load(obj_path, process=False, force="mesh", skip_materials=True)
bbox_center = (obj_trimesh.vertices.min(0) + obj_trimesh.vertices.max(0)) / 2
obj_trimesh.vertices = obj_trimesh.vertices - bbox_center
self.obj_warehouse[obj_id] = obj_trimesh
return self.obj_warehouse[obj_id]

def __getitem__(self, idx):
grasp = self.grasp_list[idx]
obj_mesh = self.obj_warehouse[grasp["obj_id"]]
obj_mesh = self.get_obj_mesh(idx)
grasp["obj_verts"] = obj_mesh.vertices.astype(np.float32)
grasp["obj_faces"] = obj_mesh.faces.astype(np.int32)
if grasp["action_id"] == "0004":
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_dep():

setup(
name="oikit",
version="1.1.0",
version="1.2.0",
author="Lixin Yang",
author_email="siriusyang@sjtu.edu.cn",
description="OakInk tooKIT",
Expand Down

0 comments on commit ff42015

Please sign in to comment.