Skip to content

Commit

Permalink
Save output data as h5 (#84)
Browse files Browse the repository at this point in the history
* add compute velocity step to run_stac(), rename op_utils to utils

* fix shapes

* print statements for logging in notebook

* update save

* load h5

* fix reshape bug

* fix saving

* rework config and output io

* update run_stac

* cleanup

* linter

* lint

* Documenting config, check for freejoint in infer_qvel

* fix type

---------

Co-authored-by: Charles Zhang <charleszhang@holylogin05.rc.fas.harvard.edu>
Co-authored-by: Charles Zhang <charleszhang@boslogin07.rc.fas.harvard.edu>
  • Loading branch information
3 people authored Feb 6, 2025
1 parent f3980e4 commit 5d99b42
Show file tree
Hide file tree
Showing 36 changed files with 874 additions and 620 deletions.
6 changes: 0 additions & 6 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,3 @@ defaults:
- stac: demo
- model: rodent
- _self_

##FLY_MODEL
# defaults:
# - stac: stac_fly_tethered
# - model: fly_tethered
# - _self_
3 changes: 0 additions & 3 deletions configs/model/fly_tethered.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# MJCF_PATH: 'models/fruitfly/fruitfly_freeforce.xml'
MJCF_PATH: 'models/fruitfly/fruitfly_force_free.xml'

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 300

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
Expand Down
3 changes: 0 additions & 3 deletions configs/model/fly_treadmill.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@

MJCF_PATH: 'models/fruitfly/fruitfly_force_free.xml'

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 581

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
Expand Down
3 changes: 0 additions & 3 deletions configs/model/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
MJCF_PATH: "models/rodent.xml"

# Frames per clip for ik_only.
N_FRAMES_PER_CLIP: 250

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
Expand Down
3 changes: 0 additions & 3 deletions configs/model/synth_data.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@

MJCF_PATH: 'models/synth_model.xml'

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 1

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
Expand Down
8 changes: 5 additions & 3 deletions configs/stac/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
fit_offsets_path: "demo_fit.p"
ik_only_path: "demo_ik_only.p"
fit_offsets_path: "demo_fit_offsets.h5"
ik_only_path: "demo_ik_only.h5"
data_path: "tests/data/test_rodent_mocap_1000_frames.mat"

n_fit_frames: 1
n_fit_frames: 10
skip_fit_offsets: False
skip_ik_only: True
infer_qvels: True # Infer qvels from stac output
n_frames_per_clip: 250 # 1 if mocap is one session

mujoco:
solver: "newton"
Expand Down
11 changes: 2 additions & 9 deletions configs/stac/stac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,8 @@ data_path: "tests/data/test_rodent_mocap_1000_frames.nwb"
n_fit_frames: 1000
skip_fit_offsets: False
skip_ik_only: True

##FLY_MODEL
# fit_path: "fit_tethered.p"
# transform_path: "transform_tethered.p"
# data_path: "tests/data/test_rodent_mocap_1000_frames.nwb"
#
#n_fit_frames: 601
#skip_fit_offsets: False
#skip_ik_only: False
infer_qvels: True
n_frames_per_clip: 250

mujoco:
solver: "newton"
Expand Down
2 changes: 2 additions & 0 deletions configs/stac/stac_fly_tethered.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ data_path: "tests/data/bout_dict.h5"
gpu: '0'

n_fit_frames: 601
n_frames_per_clip: 300

skip_fit: True
skip_transform: False

Expand Down
2 changes: 2 additions & 0 deletions configs/stac/stac_fly_treadmill.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ ik_only_path: "transform_treadmill.p"
data_path: "../tests/data/wt_berlin_linear_treadmill_dataset.csv"

n_fit_frames: 1800
n_frames_per_clip: 581

skip_fit: False
skip_transform: False
gpu: '1'
Expand Down
2 changes: 2 additions & 0 deletions configs/stac/stac_mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ data_path: "tests/data/test_mouse_mocap_3600_frames.h5"
n_fit_frames: 250
skip_fit_offsets: False
skip_ik_only: True
infer_qvels: True
n_frames_per_clip: 360

mujoco:
solver: "newton"
Expand Down
8 changes: 5 additions & 3 deletions configs/stac/stac_synth_data.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
fit_offsets_path: "synth_fit.p"
ik_only_path: "synth_ik_only.p"
fit_offsets_path: "synth_fit.h5"
ik_only_path: "synth_ik_only.h5"
data_path: "tests/data/test_synth_1_frames.nwb"

n_fit_frames: 1
skip_fit_offsets: False
skip_ik_only: False
skip_ik_only: True
infer_qvels: False # Infer qvels from stac output
n_frames_per_clip: 1

mujoco:
solver: newton
Expand Down
270 changes: 145 additions & 125 deletions demos/api_usage.ipynb

Large diffs are not rendered by default.

30 changes: 12 additions & 18 deletions demos/viz_usage.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion run_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def load_and_run_stac(cfg):
kp_data, sorted_kp_names = stac_mjx.load_data(cfg)
kp_data, sorted_kp_names = stac_mjx.load_mocap(cfg)

fit_path, ik_only_path = stac_mjx.run_stac(cfg, kp_data, sorted_kp_names)

Expand Down
4 changes: 2 additions & 2 deletions stac_mjx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module exposes all high level APIs for stac-mjx."""

from stac_mjx.op_utils import enable_xla_flags
from stac_mjx.io import load_data
from stac_mjx.utils import enable_xla_flags
from stac_mjx.io import load_mocap
from stac_mjx.main import load_configs, run_stac
from stac_mjx.viz import viz_stac
57 changes: 26 additions & 31 deletions stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import jax
import jax.numpy as jp

import numpy as np
from typing import Tuple, List
import time

from stac_mjx import stac_core
from stac_mjx import op_utils
from stac_mjx import utils


def root_optimization(
Expand Down Expand Up @@ -46,14 +46,6 @@ def root_optimization(
s = time.time()
q0 = jp.copy(mjx_data.qpos[:])

# Set the root_kp_index below according to a keypoint in the
# KEYPOINT_MODEL_PAIRS that is near the center of the model, not
# necessarily exactly so. The value of 3*18 is chosen for the
# rodent.xml, corresponding to the index of 'SpineL' keypoint.
# For the mouse model this should be 3*5, corresponding 'Trunk'
# root_kp_idx = 3 * 18
# FLY_MODEL:
# root_kp_idx = 0
q0.at[:3].set(kp_data[frame, :][root_kp_idx : root_kp_idx + 3])
qs_to_opt = jp.zeros_like(q0, dtype=bool)
qs_to_opt = qs_to_opt.at[:7].set(True)
Expand All @@ -75,8 +67,8 @@ def root_optimization(

r = time.time()

mjx_data = op_utils.replace_qs(
mjx_model, mjx_data, op_utils.make_qs(q0, qs_to_opt, res.params)
mjx_data = utils.replace_qs(
mjx_model, mjx_data, utils.make_qs(q0, qs_to_opt, res.params)
)
print(f"Replace 1 finished in {time.time()-r}")

Expand All @@ -101,8 +93,8 @@ def root_optimization(
print(f"q_opt 2 finished in {time.time()-j} with an error of {res.state.error}")
r = time.time()

mjx_data = op_utils.replace_qs(
mjx_model, mjx_data, op_utils.make_qs(q0, qs_to_opt, res.params)
mjx_data = utils.replace_qs(
mjx_model, mjx_data, utils.make_qs(q0, qs_to_opt, res.params)
)

print(f"Replace 2 finished in {time.time()-r}")
Expand Down Expand Up @@ -149,7 +141,7 @@ def offset_optimization(
print("Begining offset optimization:")

# Define initial position of the optimization
offset0 = op_utils.get_site_pos(mjx_model, site_idxs).flatten()
offset0 = utils.get_site_pos(mjx_model, site_idxs).flatten()

keypoints = jp.array(kp_data[time_indices, :])
q = jp.take(q, time_indices, axis=0)
Expand All @@ -170,12 +162,12 @@ def offset_optimization(
print(f"Final error of {res.state.error}")

# Set body sites according to optimized offsets
mjx_model = op_utils.set_site_pos(
mjx_model = utils.set_site_pos(
mjx_model, jp.reshape(offset_opt_param, (-1, 3)), site_idxs
)

# Forward kinematics, and save the results to the walker sites as well
mjx_data = op_utils.kinematics(mjx_model, mjx_data)
mjx_data = utils.kinematics(mjx_model, mjx_data)

print(f"offset optimization finished in {time.time()-s}")

Expand Down Expand Up @@ -206,9 +198,10 @@ def pose_optimization(
Tuple: Updated mjx.Data, optimized qpos, offset site xpos, mjx.Data.xpos for each frame, and info for logging (optimization time and errors)
"""
s = time.time()
q = []
x = []
walker_body_sites = []
qposes = []
xposes = []
xquats = []
marker_sites = []

# Iterate through all of the frames
frames = jp.arange(kp_data.shape[0])
Expand All @@ -233,7 +226,7 @@ def f(mjx_data, kp_data, n_frame, parts):
site_idxs,
)

mjx_data = op_utils.replace_qs(mjx_model, mjx_data, res.params)
mjx_data = utils.replace_qs(mjx_model, mjx_data, res.params)

for part in parts:
q0 = jp.copy(mjx_data.qpos[:])
Expand All @@ -250,8 +243,8 @@ def f(mjx_data, kp_data, n_frame, parts):
site_idxs,
)

mjx_data = op_utils.replace_qs(
mjx_model, mjx_data, op_utils.make_qs(q0, part, res.params)
mjx_data = utils.replace_qs(
mjx_model, mjx_data, utils.make_qs(q0, part, res.params)
)

return mjx_data, res.state.error
Expand All @@ -264,19 +257,21 @@ def f(mjx_data, kp_data, n_frame, parts):

mjx_data, error = f(mjx_data, kp_data, n_frame, indiv_parts)

q.append(mjx_data.qpos[:])
x.append(mjx_data.xpos[:])
walker_body_sites.append(op_utils.get_site_xpos(mjx_data, site_idxs))
qposes.append(mjx_data.qpos[:])
xposes.append(mjx_data.xpos[:])
xquats.append(mjx_data.xquat[:])
marker_sites.append(utils.get_site_xpos(mjx_data, site_idxs))

frame_time.append(time.time() - loop_start)
frame_error.append(error)

print(f"Pose Optimization done in {time.time()-s}")
return (
mjx_data,
jp.array(q),
jp.array(walker_body_sites),
jp.array(x),
jp.array(frame_time),
jp.array(frame_error),
jp.array(qposes),
xposes,
xquats,
marker_sites,
frame_time,
frame_error,
)
Loading

0 comments on commit 5d99b42

Please sign in to comment.