Skip to content

Commit

Permalink
Resolve reshape bug (#86)
Browse files Browse the repository at this point in the history
* fix xpos and xquat reshape order

* fix compute_velocity args

---------

Co-authored-by: Charles Zhang <charleszhang@holylogin06.rc.fas.harvard.edu>
  • Loading branch information
charles-zhng and Charles Zhang authored Feb 12, 2025
1 parent 5d99b42 commit f1957d3
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 46 deletions.
30 changes: 1 addition & 29 deletions stac_mjx/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def load_h5(filename):
Returns:
dict: Dictionary containing the data from the .h5 file.
"""
# TODO add track information
# TODO tracks is a hardcoded dataset name
data = {}
with h5py.File(filename, "r") as f:
for key in f.keys():
Expand Down Expand Up @@ -253,34 +253,6 @@ def _todict(matobj):
return dict


def _load_params(param_path):
"""Load parameters for the animal.
:param param_path: Path to .yaml file specifying animal parameters.
"""
with open(param_path, "r") as infile:
try:
params = yaml.safe_load(infile)
except yaml.YAMLError as exc:
print(exc)
return params


def save_dict_to_hdf5(group, dictionary):
"""Save a dictionary to an HDF5 group.
Args:
group (h5py.Group): HDF5 group to save the dictionary to.
dictionary (dict): Dictionary to save.
"""
for key, value in dictionary.items():
if isinstance(value, dict):
subgroup = group.create_group(key)
save_dict_to_hdf5(subgroup, value)
else:
group.attrs[key] = value


def save_data_to_h5(
config: Config,
kp_names: list,
Expand Down
4 changes: 1 addition & 3 deletions stac_mjx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def run_stac(
)
if cfg.stac.infer_qvels:
t_vel = time.time()
qvels = vmap_compute_velocity_fn(
qpos_trajectory=batched_qpos, freejoint=stac._freejoint
)
qvels = vmap_compute_velocity_fn(qpos_trajectory=batched_qpos)
# set dict key after reshaping and casting to numpy
ik_only_data.qvel = np.array(qvels).reshape(-1, *qvels.shape[2:])
print(f"Finished compute velocity in {time.time() - t_vel}")
Expand Down
19 changes: 9 additions & 10 deletions stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,10 @@ def _create_body_sites(self, root: mjcf.Element):
else:
is_regularized.append(jp.array([0.0, 0.0, 0.0]))
is_regularized = jp.stack(is_regularized).flatten()

body_site_idxs = jp.array(list(site_index_map.values()))
return (
# physics,
physics.model.ptr,
jp.array(list(site_index_map.values())),
body_site_idxs,
is_regularized,
)

Expand Down Expand Up @@ -421,8 +420,8 @@ def _package_data(
get_batch_offsets = jax.vmap(utils.get_site_pos, in_axes=(0, None))
offsets = get_batch_offsets(mjx_model, self._body_site_idxs)[0]
qposes = qposes.reshape(-1, qposes.shape[-1])
xposes = xposes.reshape(-1, *xposes.shape[2:])
xquats = xquats.reshape(-1, *xquats.shape[2:])
xposes = xposes.reshape(-1, *xposes.shape[2:], order="F")
xquats = xquats.reshape(-1, *xquats.shape[2:], order="F")
marker_sites = marker_sites.reshape(-1, *marker_sites.shape[2:])
else:
offsets = self._offsets.reshape((-1, 3))
Expand All @@ -442,7 +441,7 @@ def _package_data(
kp_names=self._kp_names,
)

def _create_keypoint_sites(self):
def _create_render_sites(self):
"""Create sites for keypoints (used for rendering only).
Returns:
Expand Down Expand Up @@ -480,10 +479,10 @@ def _create_keypoint_sites(self):
site_index_map[n] for n in self.cfg.model.KEYPOINT_MODEL_PAIRS.keys()
]
keypoint_site_idxs = [site_index_map[n] for n in keypoint_site_names]

self._body_site_idxs = body_site_idxs
self._keypoint_site_idxs = keypoint_site_idxs

return deepcopy(physics.model.ptr), body_site_idxs, keypoint_site_idxs
return (deepcopy(physics.model.ptr), body_site_idxs, keypoint_site_idxs)

def render(
self,
Expand Down Expand Up @@ -534,7 +533,7 @@ def render(
)

render_mj_model, body_site_idxs, keypoint_site_idxs = (
self._create_keypoint_sites()
self._create_render_sites()
)

# Add body sites for new offsets
Expand Down Expand Up @@ -596,7 +595,7 @@ def render(
qposes = qposes[start_frame : start_frame + n_frames]

frames = []
# render while stepping using mujoco
# Render while stepping using mujoco
with imageio.get_writer(save_path, fps=self.cfg.model.RENDER_FPS) as video:
for qpos, kps in tqdm(zip(qposes, kp_data)):
# Set keypoints--they're in cartesian space, but since they're attached to the worldbody they're the same as offsets
Expand Down
1 change: 1 addition & 0 deletions stac_mjx/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def viz_stac(
# initialize stac to create mj_model with scaling and marker body sites according to config
# Set the learned offsets for body sites manually
stac = Stac(xml_path, cfg, kp_names)

return cfg, stac.render(
qposes,
kp_data,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def test_load_nwb(config, mocap_nwb):
"""
Test loading data from .nwb file.
"""
# params = utils._load_params(_BASE_PATH / rodent_config)
# assert params is not None
cfg = load_config_with_overrides(config, stac_data_path_override=mocap_nwb)
data, sorted_kp_names = io.load_mocap(cfg)
assert data.shape == (1000, 69)
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def test_load_nwb(config, mocap_nwb):
"""
Test loading data from .nwb file.
"""
# params = utils._load_params(_BASE_PATH / rodent_config)
# assert params is not None
cfg = load_config_with_overrides(config, stac_data_path_override=mocap_nwb)
data, sorted_kp_names = io.load_mocap(cfg)
assert data.shape == (1000, 69)
Expand Down

0 comments on commit f1957d3

Please sign in to comment.