From f1957d363089f972607a480c99623512fa515c0a Mon Sep 17 00:00:00 2001 From: Charles Zhang <33401293+charles-zhng@users.noreply.github.com> Date: Tue, 11 Feb 2025 21:26:22 -0500 Subject: [PATCH] Resolve reshape bug (#86) * fix xpos and xquat reshape order * fix compute_velocity args --------- Co-authored-by: Charles Zhang --- stac_mjx/io.py | 30 +----------------------------- stac_mjx/main.py | 4 +--- stac_mjx/stac.py | 19 +++++++++---------- stac_mjx/viz.py | 1 + tests/test_io.py | 2 -- tests/unit/test_utils.py | 2 -- 6 files changed, 12 insertions(+), 46 deletions(-) diff --git a/stac_mjx/io.py b/stac_mjx/io.py index a4709a2..1f3858d 100755 --- a/stac_mjx/io.py +++ b/stac_mjx/io.py @@ -218,7 +218,7 @@ def load_h5(filename): Returns: dict: Dictionary containing the data from the .h5 file. """ - # TODO add track information + # TODO tracks is a hardcoded dataset name data = {} with h5py.File(filename, "r") as f: for key in f.keys(): @@ -253,34 +253,6 @@ def _todict(matobj): return dict -def _load_params(param_path): - """Load parameters for the animal. - - :param param_path: Path to .yaml file specifying animal parameters. - """ - with open(param_path, "r") as infile: - try: - params = yaml.safe_load(infile) - except yaml.YAMLError as exc: - print(exc) - return params - - -def save_dict_to_hdf5(group, dictionary): - """Save a dictionary to an HDF5 group. - - Args: - group (h5py.Group): HDF5 group to save the dictionary to. - dictionary (dict): Dictionary to save. - """ - for key, value in dictionary.items(): - if isinstance(value, dict): - subgroup = group.create_group(key) - save_dict_to_hdf5(subgroup, value) - else: - group.attrs[key] = value - - def save_data_to_h5( config: Config, kp_names: list, diff --git a/stac_mjx/main.py b/stac_mjx/main.py index 0660415..5ccc584 100644 --- a/stac_mjx/main.py +++ b/stac_mjx/main.py @@ -115,9 +115,7 @@ def run_stac( ) if cfg.stac.infer_qvels: t_vel = time.time() - qvels = vmap_compute_velocity_fn( - qpos_trajectory=batched_qpos, freejoint=stac._freejoint - ) + qvels = vmap_compute_velocity_fn(qpos_trajectory=batched_qpos) # set dict key after reshaping and casting to numpy ik_only_data.qvel = np.array(qvels).reshape(-1, *qvels.shape[2:]) print(f"Finished compute velocity in {time.time() - t_vel}") diff --git a/stac_mjx/stac.py b/stac_mjx/stac.py index 201c994..12d4f64 100644 --- a/stac_mjx/stac.py +++ b/stac_mjx/stac.py @@ -175,11 +175,10 @@ def _create_body_sites(self, root: mjcf.Element): else: is_regularized.append(jp.array([0.0, 0.0, 0.0])) is_regularized = jp.stack(is_regularized).flatten() - + body_site_idxs = jp.array(list(site_index_map.values())) return ( - # physics, physics.model.ptr, - jp.array(list(site_index_map.values())), + body_site_idxs, is_regularized, ) @@ -421,8 +420,8 @@ def _package_data( get_batch_offsets = jax.vmap(utils.get_site_pos, in_axes=(0, None)) offsets = get_batch_offsets(mjx_model, self._body_site_idxs)[0] qposes = qposes.reshape(-1, qposes.shape[-1]) - xposes = xposes.reshape(-1, *xposes.shape[2:]) - xquats = xquats.reshape(-1, *xquats.shape[2:]) + xposes = xposes.reshape(-1, *xposes.shape[2:], order="F") + xquats = xquats.reshape(-1, *xquats.shape[2:], order="F") marker_sites = marker_sites.reshape(-1, *marker_sites.shape[2:]) else: offsets = self._offsets.reshape((-1, 3)) @@ -442,7 +441,7 @@ def _package_data( kp_names=self._kp_names, ) - def _create_keypoint_sites(self): + def _create_render_sites(self): """Create sites for keypoints (used for rendering only). Returns: @@ -480,10 +479,10 @@ def _create_keypoint_sites(self): site_index_map[n] for n in self.cfg.model.KEYPOINT_MODEL_PAIRS.keys() ] keypoint_site_idxs = [site_index_map[n] for n in keypoint_site_names] + self._body_site_idxs = body_site_idxs self._keypoint_site_idxs = keypoint_site_idxs - - return deepcopy(physics.model.ptr), body_site_idxs, keypoint_site_idxs + return (deepcopy(physics.model.ptr), body_site_idxs, keypoint_site_idxs) def render( self, @@ -534,7 +533,7 @@ def render( ) render_mj_model, body_site_idxs, keypoint_site_idxs = ( - self._create_keypoint_sites() + self._create_render_sites() ) # Add body sites for new offsets @@ -596,7 +595,7 @@ def render( qposes = qposes[start_frame : start_frame + n_frames] frames = [] - # render while stepping using mujoco + # Render while stepping using mujoco with imageio.get_writer(save_path, fps=self.cfg.model.RENDER_FPS) as video: for qpos, kps in tqdm(zip(qposes, kp_data)): # Set keypoints--they're in cartesian space, but since they're attached to the worldbody they're the same as offsets diff --git a/stac_mjx/viz.py b/stac_mjx/viz.py index d305093..9c451c9 100644 --- a/stac_mjx/viz.py +++ b/stac_mjx/viz.py @@ -51,6 +51,7 @@ def viz_stac( # initialize stac to create mj_model with scaling and marker body sites according to config # Set the learned offsets for body sites manually stac = Stac(xml_path, cfg, kp_names) + return cfg, stac.render( qposes, kp_data, diff --git a/tests/test_io.py b/tests/test_io.py index e3a0d72..164fe9c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -28,8 +28,6 @@ def test_load_nwb(config, mocap_nwb): """ Test loading data from .nwb file. """ - # params = utils._load_params(_BASE_PATH / rodent_config) - # assert params is not None cfg = load_config_with_overrides(config, stac_data_path_override=mocap_nwb) data, sorted_kp_names = io.load_mocap(cfg) assert data.shape == (1000, 69) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index e3a0d72..164fe9c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -28,8 +28,6 @@ def test_load_nwb(config, mocap_nwb): """ Test loading data from .nwb file. """ - # params = utils._load_params(_BASE_PATH / rodent_config) - # assert params is not None cfg = load_config_with_overrides(config, stac_data_path_override=mocap_nwb) data, sorted_kp_names = io.load_mocap(cfg) assert data.shape == (1000, 69)