Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic python api #31

Merged
merged 22 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
lcov.info

.DS_Store
snippets*
# error data files
Expand Down
19 changes: 10 additions & 9 deletions configs/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Frames per clip for transform.
N_FRAMES_PER_CLIP: 360
N_FRAMES_PER_CLIP: 250

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
jf514 marked this conversation as resolved.
Show resolved Hide resolved
# FTOL: 5.0e-03
# ROOT_FTOL: 1.0e-05
# LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

KP_NAMES:
- 'Snout'
Expand Down Expand Up @@ -171,14 +180,6 @@ RENDER_FPS: 50

N_SAMPLE_FRAMES: 100

# Tolerance for the optimizations of the full model, limb, and root.
FTOL: 1.0e-02
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
Expand Down
15 changes: 5 additions & 10 deletions configs/stac.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
paths:
model_config: "rodent"
xml: "././models/rodent.xml"
fit_path: "fit_sq.p"
transform_path: "transform_sq.p"
xml: "./models/rodent.xml"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "./tests/data/test_pred_only_1000_frames.mat"

n_fit_frames: 1000
sampler: "first" # first, every, or random
first_start: 0 # starting frame for "first" sampler

# Should this be included?
test:
skip_fit: False
skip_transform: False
skip_fit: False
skip_transform: True

mujoco:
solver: "newton"
Expand Down
Empty file added conftest.py
Empty file.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- pytest-cov
- glfw3
- ipykernel
- pip:
- pip:
talmo marked this conversation as resolved.
Show resolved Hide resolved
- -r requirements.txt

variables: # Set MuJoCo environment variables
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ hydra-core
imageio
h5py
flax[all]
optax[all]
optax[all] >= 0.2.3
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]
imageio[pyav]
Expand Down
54 changes: 54 additions & 0 deletions run_rodent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import jax
from jax import numpy as jnp
from jax.lib import xla_bridge
import numpy as np

import os
import logging
import hydra
from omegaconf import DictConfig, OmegaConf

from stac_mjx import main
from stac_mjx import utils


@hydra.main(config_path="./configs", config_name="stac", version_base=None)
def hydra_entry(cfg: DictConfig):
# Initialize configs and convert to dictionaries
global_cfg = hydra.compose(config_name=cfg.paths.model_config)
logging.info(f"cfg: {OmegaConf.to_yaml(cfg)}")
logging.info(f"global_cfg: {OmegaConf.to_yaml(global_cfg)}")
utils.init_params(OmegaConf.to_container(global_cfg, resolve=True))

# XLA flags for Nvidia GPU
if xla_bridge.get_backend().platform == "gpu":
os.environ["XLA_FLAGS"] = (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=True "
)
# Set N_GPUS
utils.params["N_GPUS"] = jax.local_device_count("gpu")

# Set up mocap data
kp_names = utils.params["KP_NAMES"]
# argsort returns the indices that sort the array to match the order of marker sites
stac_keypoint_order = np.argsort(kp_names)
jf514 marked this conversation as resolved.
Show resolved Hide resolved
data_path = cfg.paths.data_path

# Load kp_data, /1000 to scale data (from mm to meters)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

# Preparing DANNCE data by reordering and reshaping
# Resulting kp_data is of shape (n_frames, n_keypoints)
kp_data = jnp.array(kp_data[:, :, stac_keypoint_order])
kp_data = jnp.transpose(kp_data, (0, 2, 1))
kp_data = jnp.reshape(kp_data, (kp_data.shape[0], -1))

return main.run_stac(cfg, kp_data)


if __name__ == "__main__":
fit_path, transform_path = hydra_entry()
logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)
43 changes: 6 additions & 37 deletions stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Compute stac optimization on data."""

import jax
from jax import vmap
import jax.numpy as jnp
import stac_base
import operations as op
import utils
from typing import List, Dict, Tuple, Text

from typing import Tuple
import time
import logging

from stac_mjx import stac_base
from stac_mjx import utils
from stac_mjx import operations as op


def root_optimization(mjx_model, mjx_data, kp_data, frame: int = 0):
Expand Down Expand Up @@ -250,34 +250,3 @@ def f(mjx_data, kp_data, n_frame, parts):
jnp.array(frame_time),
jnp.array(frame_error),
)


def package_data(mjx_model, physics, q, x, walker_body_sites, kp_data, batched=False):
"""Extract pose, offsets, data, and all parameters."""
if batched:
# prepare batched data to be packaged
get_batch_offsets = vmap(op.get_site_pos)
offsets = get_batch_offsets(mjx_model).copy()[0]
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])
else:
offsets = op.get_site_pos(mjx_model).copy()

names_xpos = physics.named.data.xpos.axes.row.names

print(f"shape of qpos: {q.shape}")
kp_data = kp_data.reshape(-1, kp_data.shape[-1])
data = {
"qpos": q,
"xpos": x,
"walker_body_sites": walker_body_sites,
"offsets": offsets,
"names_qpos": utils.params["part_names"],
"names_xpos": names_xpos,
"kp_data": jnp.copy(kp_data),
}

for k, v in utils.params.items():
data[k] = v

return data
61 changes: 43 additions & 18 deletions stac_mjx/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@
from jax import vmap
from jax import numpy as jnp

import mujoco
from mujoco import mjx

import numpy as np

from typing import Text

from dm_control import mjcf
from dm_control.locomotion.walkers import rescale

import utils
from compute_stac import *
import operations as op

import pickle
import logging
import os
from stac_mjx import utils as utils
from stac_mjx import compute_stac
from stac_mjx import operations as op
from typing import List
from statistics import fmean, pstdev


Expand Down Expand Up @@ -206,12 +200,12 @@ def fit(mj_model, kp_data):
utils.params["ub"] = ub

# Begin optimization steps
mjx_data = root_optimization(mjx_model, mjx_data, kp_data)
mjx_data = compute_stac.root_optimization(mjx_model, mjx_data, kp_data)

for n_iter in range(utils.params["N_ITERS"]):
print(f"Calibration iteration: {n_iter + 1}/{utils.params['N_ITERS']}")
mjx_data, q, walker_body_sites, x, frame_time, frame_error = pose_optimization(
mjx_model, mjx_data, kp_data
mjx_data, q, walker_body_sites, x, frame_time, frame_error = (
compute_stac.pose_optimization(mjx_model, mjx_data, kp_data)
)

for i, (t, e) in enumerate(zip(frame_time, frame_error)):
Expand All @@ -224,14 +218,14 @@ def fit(mj_model, kp_data):
print(f"Standard deviation: {std}")

print("starting offset optimization")
mjx_model, mjx_data = offset_optimization(
mjx_model, mjx_data = compute_stac.offset_optimization(
mjx_model, mjx_data, kp_data, offsets, q
)

# Optimize the pose for the whole sequence
print("Final pose optimization")
mjx_data, q, walker_body_sites, x, frame_time, frame_error = pose_optimization(
mjx_model, mjx_data, kp_data
mjx_data, q, walker_body_sites, x, frame_time, frame_error = (
compute_stac.pose_optimization(mjx_model, mjx_data, kp_data)
)

for i, (t, e) in enumerate(zip(frame_time, frame_error)):
Expand Down Expand Up @@ -289,8 +283,8 @@ def mjx_setup(kp_data, mj_model):
mjx_model, mjx_data = vmap_mjx_setup(kp_data, mj_model)

# Vmap optimize functions
vmap_root_opt = vmap(root_optimization)
vmap_pose_opt = vmap(pose_optimization)
vmap_root_opt = vmap(compute_stac.root_optimization)
vmap_pose_opt = vmap(compute_stac.pose_optimization)

# q_phase
mjx_data = vmap_root_opt(mjx_model, mjx_data, kp_data)
Expand All @@ -305,3 +299,34 @@ def mjx_setup(kp_data, mj_model):
print(f"Standard deviation: {std}")

return mjx_model, q, x, walker_body_sites, kp_data


def package_data(mjx_model, physics, q, x, walker_body_sites, kp_data, batched=False):
"""Extract pose, offsets, data, and all parameters."""
if batched:
# prepare batched data to be packaged
get_batch_offsets = vmap(op.get_site_pos)
offsets = get_batch_offsets(mjx_model).copy()[0]
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])
else:
offsets = op.get_site_pos(mjx_model).copy()

names_xpos = physics.named.data.xpos.axes.row.names

print(f"shape of qpos: {q.shape}")
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
kp_data = kp_data.reshape(-1, kp_data.shape[-1])
data = {
"qpos": q,
"xpos": x,
"walker_body_sites": walker_body_sites,
"offsets": offsets,
"names_qpos": utils.params["part_names"],
"names_xpos": names_xpos,
"kp_data": jnp.copy(kp_data),
}

for k, v in utils.params.items():
data[k] = v

return data
Loading