Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic python api #31

Merged
merged 22 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions configs/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Frames per clip for transform.
N_FRAMES_PER_CLIP: 360
N_FRAMES_PER_CLIP: 250

# Tolerance for the optimizations of the full model, limb, and root.
FTOL: 5.0e-03
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

KP_NAMES:
- 'Snout'
Expand Down Expand Up @@ -171,14 +179,6 @@ RENDER_FPS: 50

N_SAMPLE_FRAMES: 100

# Tolerance for the optimizations of the full model, limb, and root.
FTOL: 1.0e-02
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
Expand Down
15 changes: 5 additions & 10 deletions configs/stac.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
paths:
model_config: "rodent"
xml: "././models/rodent.xml"
fit_path: "fit_sq.p"
transform_path: "transform_sq.p"
xml: "./models/rodent.xml"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "./tests/data/test_pred_only_1000_frames.mat"

n_fit_frames: 1000
sampler: "first" # first, every, or random
first_start: 0 # starting frame for "first" sampler

# Should this be included?
test:
skip_fit: False
skip_transform: False
skip_fit: False
skip_transform: True

mujoco:
solver: "newton"
Expand Down
22 changes: 22 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: stac-mjx-env
channels:
- conda-forge
- anaconda
- menpo
dependencies:
- python=3.11
- numpy<2.0
- matplotlib
- pandas
- pip
- glew
- mesalib
- mesa-libgl-cos6-x86_64
- glfw3
- ipykernel
- pip:
jf514 marked this conversation as resolved.
Show resolved Hide resolved
- -r requirements.txt
jf514 marked this conversation as resolved.
Show resolved Hide resolved

variables: # Set MuJoCo environment variables
MUJOCO_GL: osmesa
PYOPENGL_PLATFORM: osmesa
jf514 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ hydra-core
imageio
h5py
flax[all]
optax[all]
optax[all] >= 0.2.3
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]
imageio[pyav]
Expand Down
54 changes: 54 additions & 0 deletions run_rodent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import jax
from jax import numpy as jnp
from jax.lib import xla_bridge
import numpy as np

import os
import logging
import hydra
from omegaconf import DictConfig, OmegaConf

from stac_mjx import main
from stac_mjx import utils


@hydra.main(config_path="./configs", config_name="stac", version_base=None)
def hydra_entry(cfg: DictConfig):
# Initialize configs and convert to dictionaries
global_cfg = hydra.compose(config_name=cfg.paths.model_config)
logging.info(f"cfg: {OmegaConf.to_yaml(cfg)}")
logging.info(f"global_cfg: {OmegaConf.to_yaml(global_cfg)}")
utils.init_params(OmegaConf.to_container(global_cfg, resolve=True))

# XLA flags for Nvidia GPU
if xla_bridge.get_backend().platform == "gpu":
os.environ["XLA_FLAGS"] = (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=True "
)
# Set N_GPUS
utils.params["N_GPUS"] = jax.local_device_count("gpu")

# Set up mocap data
kp_names = utils.params["KP_NAMES"]
# argsort returns the indices that sort the array to match the order of marker sites
stac_keypoint_order = np.argsort(kp_names)
jf514 marked this conversation as resolved.
Show resolved Hide resolved
data_path = cfg.paths.data_path

# Load kp_data, /1000 to scale data (from mm to meters)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

# Preparing data by reordering and reshaping (TODO: will this stay the same?)
# Resulting kp_data is of shape (n_frames, n_keypoints)
kp_data = jnp.array(kp_data[:, :, stac_keypoint_order])
kp_data = jnp.transpose(kp_data, (0, 2, 1))
kp_data = jnp.reshape(kp_data, (kp_data.shape[0], -1))

return main.run_stac(cfg, kp_data)
jf514 marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
fit_path, transform_path = hydra_entry()
logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)
36 changes: 36 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Setup file for stac."""
from setuptools import setup, find_packages

setup(
name="stac-mjx",
version="0.0.1",
python_requires=">=3.11",
packages=find_packages(),
install_requires=[
"six >= 1.12.0",
"clize >= 4.0.3",
"absl-py >= 0.7.1",
"mujoco-mjx >= 3.1.5",
"dm_control",
"jaxopt",
"flax",
"enum34",
"future",
"lxml",
"mediapy",
"numpy < 2.0",
"pyopengl",
"pyparsing",
"h5py >= 2.9.0",
"scipy >= 1.2.1",
"pyyaml",
"opencv-python",
"imageio",
"matplotlib",
"hydra-core",
"optax",
"colorama",
"imageio[pyav]",
"imageio[ffmpeg]",
],
)
2 changes: 1 addition & 1 deletion stac_mjx/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""This module exposes all high level APIs for stac-mjx."""
"""init file."""
43 changes: 6 additions & 37 deletions stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Compute stac optimization on data."""

import jax
from jax import vmap
import jax.numpy as jnp
import stac_base
import operations as op
import utils
from typing import List, Dict, Tuple, Text

from typing import Tuple

Check warning on line 6 in stac_mjx/compute_stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/compute_stac.py#L6

Added line #L6 was not covered by tests
import time
import logging

from stac_mjx import stac_base
from stac_mjx import utils
from stac_mjx import operations as op

Check warning on line 11 in stac_mjx/compute_stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/compute_stac.py#L9-L11

Added lines #L9 - L11 were not covered by tests


def root_optimization(mjx_model, mjx_data, kp_data, frame: int = 0):
Expand Down Expand Up @@ -250,34 +250,3 @@
jnp.array(frame_time),
jnp.array(frame_error),
)


def package_data(mjx_model, physics, q, x, walker_body_sites, kp_data, batched=False):
"""Extract pose, offsets, data, and all parameters."""
if batched:
# prepare batched data to be packaged
get_batch_offsets = vmap(op.get_site_pos)
offsets = get_batch_offsets(mjx_model).copy()[0]
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])
else:
offsets = op.get_site_pos(mjx_model).copy()

names_xpos = physics.named.data.xpos.axes.row.names

print(f"shape of qpos: {q.shape}")
kp_data = kp_data.reshape(-1, kp_data.shape[-1])
data = {
"qpos": q,
"xpos": x,
"walker_body_sites": walker_body_sites,
"offsets": offsets,
"names_qpos": utils.params["part_names"],
"names_xpos": names_xpos,
"kp_data": jnp.copy(kp_data),
}

for k, v in utils.params.items():
data[k] = v

return data
61 changes: 43 additions & 18 deletions stac_mjx/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@
from jax import vmap
from jax import numpy as jnp

import mujoco
from mujoco import mjx

import numpy as np

from typing import Text

from dm_control import mjcf
from dm_control.locomotion.walkers import rescale

import utils
from compute_stac import *
import operations as op

import pickle
import logging
import os
from stac_mjx import utils as utils
from stac_mjx import compute_stac
from stac_mjx import operations as op
from typing import List

Check warning on line 16 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L13-L16

Added lines #L13 - L16 were not covered by tests
from statistics import fmean, pstdev


Expand Down Expand Up @@ -206,12 +200,12 @@
utils.params["ub"] = ub

# Begin optimization steps
mjx_data = root_optimization(mjx_model, mjx_data, kp_data)
mjx_data = compute_stac.root_optimization(mjx_model, mjx_data, kp_data)

Check warning on line 203 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L203

Added line #L203 was not covered by tests

for n_iter in range(utils.params["N_ITERS"]):
print(f"Calibration iteration: {n_iter + 1}/{utils.params['N_ITERS']}")
mjx_data, q, walker_body_sites, x, frame_time, frame_error = pose_optimization(
mjx_model, mjx_data, kp_data
mjx_data, q, walker_body_sites, x, frame_time, frame_error = (

Check warning on line 207 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L207

Added line #L207 was not covered by tests
compute_stac.pose_optimization(mjx_model, mjx_data, kp_data)
)

for i, (t, e) in enumerate(zip(frame_time, frame_error)):
Expand All @@ -224,14 +218,14 @@
print(f"Standard deviation: {std}")

print("starting offset optimization")
mjx_model, mjx_data = offset_optimization(
mjx_model, mjx_data = compute_stac.offset_optimization(

Check warning on line 221 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L221

Added line #L221 was not covered by tests
mjx_model, mjx_data, kp_data, offsets, q
)

# Optimize the pose for the whole sequence
print("Final pose optimization")
mjx_data, q, walker_body_sites, x, frame_time, frame_error = pose_optimization(
mjx_model, mjx_data, kp_data
mjx_data, q, walker_body_sites, x, frame_time, frame_error = (

Check warning on line 227 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L227

Added line #L227 was not covered by tests
compute_stac.pose_optimization(mjx_model, mjx_data, kp_data)
)

for i, (t, e) in enumerate(zip(frame_time, frame_error)):
Expand Down Expand Up @@ -289,8 +283,8 @@
mjx_model, mjx_data = vmap_mjx_setup(kp_data, mj_model)

# Vmap optimize functions
vmap_root_opt = vmap(root_optimization)
vmap_pose_opt = vmap(pose_optimization)
vmap_root_opt = vmap(compute_stac.root_optimization)
vmap_pose_opt = vmap(compute_stac.pose_optimization)

Check warning on line 287 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L286-L287

Added lines #L286 - L287 were not covered by tests

# q_phase
mjx_data = vmap_root_opt(mjx_model, mjx_data, kp_data)
Expand All @@ -305,3 +299,34 @@
print(f"Standard deviation: {std}")

return mjx_model, q, x, walker_body_sites, kp_data


def package_data(mjx_model, physics, q, x, walker_body_sites, kp_data, batched=False):

Check warning on line 304 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L304

Added line #L304 was not covered by tests
"""Extract pose, offsets, data, and all parameters."""
if batched:

Check warning on line 306 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L306

Added line #L306 was not covered by tests
# prepare batched data to be packaged
get_batch_offsets = vmap(op.get_site_pos)
offsets = get_batch_offsets(mjx_model).copy()[0]
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])

Check warning on line 311 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L308-L311

Added lines #L308 - L311 were not covered by tests
else:
offsets = op.get_site_pos(mjx_model).copy()

Check warning on line 313 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L313

Added line #L313 was not covered by tests

names_xpos = physics.named.data.xpos.axes.row.names

Check warning on line 315 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L315

Added line #L315 was not covered by tests

print(f"shape of qpos: {q.shape}")
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
kp_data = kp_data.reshape(-1, kp_data.shape[-1])
data = {

Check warning on line 319 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L317-L319

Added lines #L317 - L319 were not covered by tests
"qpos": q,
"xpos": x,
"walker_body_sites": walker_body_sites,
"offsets": offsets,
"names_qpos": utils.params["part_names"],
"names_xpos": names_xpos,
"kp_data": jnp.copy(kp_data),
}

for k, v in utils.params.items():
data[k] = v

Check warning on line 330 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L329-L330

Added lines #L329 - L330 were not covered by tests

return data

Check warning on line 332 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L332

Added line #L332 was not covered by tests
Loading
Loading