Skip to content

Commit

Permalink
support ball joints (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-zhng authored Sep 11, 2024
1 parent 961483f commit 7fef98f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 27 deletions.
73 changes: 47 additions & 26 deletions stac_mjx/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,37 @@
import imageio
from tqdm import tqdm

# Root position (3) + quaternion (7) in qpos
_ROOT_QPOS_LB = -jp.inf * jp.ones(7)
_ROOT_QPOS_UB = jp.inf * jp.ones(7)
# Root = position (3) + quaternion (4)
_ROOT_QPOS_LB = jp.concatenate([-jp.inf * jp.ones(3), -1.0 * jp.ones(4)])
_ROOT_QPOS_UB = jp.concatenate([jp.inf * jp.ones(3), 1.0 * jp.ones(4)])

# mujoco jnt_type enums: https://mujoco.readthedocs.io/en/latest/APIreference/APItypes.html#mjtjoint
_MUJOCO_JOINT_TYPE_DIMS = {
mujoco.mjtJoint.mjJNT_FREE: 7,
mujoco.mjtJoint.mjJNT_BALL: 4,
mujoco.mjtJoint.mjJNT_SLIDE: 1,
mujoco.mjtJoint.mjJNT_HINGE: 1,
}


def _align_joint_dims(types, ranges, names):
"""Creates bounds and joint names aligned with qpos dimensions."""
lb = []
ub = []
part_names = []
for type, range, name in zip(types, ranges, names):
dims = _MUJOCO_JOINT_TYPE_DIMS[type]
# Set inf bounds for freejoint
if type == mujoco.mjtJoint.mjJNT_FREE:
lb.append(_ROOT_QPOS_LB)
ub.append(_ROOT_QPOS_UB)
part_names += [name] * dims
else:
lb.append(range[0] * jp.ones(dims))
ub.append(range[1] * jp.ones(dims))
part_names += [name] * dims

# Prepend this to list of part names for one-to-one correspondence with qpos
_ROOT_NAMES = ["root"] * 6
return jp.minimum(jp.concatenate(lb), 0.0), jp.concatenate(ub), part_names


class STAC:
Expand All @@ -51,37 +76,38 @@ def __init__(
self._kp_names = kp_names
self._root = mjcf.from_path(xml_path)
(
mj_model,
self._mj_model,
self._body_site_idxs,
self._is_regularized,
self._part_names,
self._body_names,
) = self._create_body_sites(self._root)

self._body_names = [
self._mj_model.body(i).name for i in range(self._mj_model.nbody)
]

joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]

# Set up bounds and part_names based on joint ranges, taking into account the dimensionality of parameters
self._lb, self._ub, self._part_names = _align_joint_dims(
self._mj_model.jnt_type, self._mj_model.jnt_range, joint_names
)

self._indiv_parts = self.part_opt_setup()

self._trunk_kps = jp.array(
[n in self.model_cfg["TRUNK_OPTIMIZATION_KEYPOINTS"] for n in kp_names],
)

mj_model.opt.solver = {
self._mj_model.opt.solver = {
"cg": mujoco.mjtSolver.mjSOL_CG,
"newton": mujoco.mjtSolver.mjSOL_NEWTON,
}[stac_cfg.mujoco.solver.lower()]

mj_model.opt.iterations = stac_cfg.mujoco.iterations
mj_model.opt.ls_iterations = stac_cfg.mujoco.ls_iterations
self._mj_model.opt.iterations = stac_cfg.mujoco.iterations
self._mj_model.opt.ls_iterations = stac_cfg.mujoco.ls_iterations

# Runs faster on GPU with this
mj_model.opt.jacobian = 0 # dense

self._mj_model = mj_model

# Set joint bounds
self._lb = jp.minimum(
jp.concatenate([_ROOT_QPOS_LB, self._mj_model.jnt_range[1:][:, 0]]),
0.0,
)
self._ub = jp.concatenate([_ROOT_QPOS_UB, self._mj_model.jnt_range[1:][:, 1]])
self._mj_model.opt.jacobian = 0 # dense

def part_opt_setup(self):
"""Set up the lists of indices for part optimization.
Expand Down Expand Up @@ -142,9 +168,6 @@ def _create_body_sites(self, root: mjcf.Element):
key: int(axis.convert_key_item(key))
for key in self.model_cfg["KEYPOINT_MODEL_PAIRS"].keys()
}
part_names = _ROOT_NAMES + physics.named.data.qpos.axes.row.names

body_names = physics.named.data.xpos.axes.row.names

# Define which offsets to regularize
is_regularized = []
Expand All @@ -160,8 +183,6 @@ def _create_body_sites(self, root: mjcf.Element):
physics.model.ptr,
jp.array(list(site_index_map.values())),
is_regularized,
part_names,
body_names,
)

def _chunk_kp_data(self, kp_data):
Expand Down
71 changes: 70 additions & 1 deletion tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from stac_mjx import main
from stac_mjx import utils
from stac_mjx.controller import STAC
from stac_mjx.controller import STAC, _align_joint_dims
from mujoco import _structs

_BASE_PATH = Path.cwd()
Expand All @@ -25,3 +25,72 @@ def test_init_stac(mocap_nwb, stac_config, rodent_config):
assert stac.model_cfg == model_cfg
assert stac._kp_names == sorted_kp_names
assert isinstance(stac._mj_model, _structs.MjModel)


def test_align_joint_dims():
from jax import numpy as jp
import mujoco

joint_types = [
mujoco.mjtJoint.mjJNT_FREE,
mujoco.mjtJoint.mjJNT_HINGE,
mujoco.mjtJoint.mjJNT_BALL,
mujoco.mjtJoint.mjJNT_SLIDE,
]
ranges = [[0.0, 0.0], [-0.1, 0.1], [0.0, 1.0], [-0.5, 0.5]]
names = ["root", "hingejoint", "balljoint", "slidejoint"]
lb, ub, part_names = _align_joint_dims(joint_types, ranges, names)
print(lb)

true_lb = jp.array(
[
-jp.inf,
-jp.inf,
-jp.inf,
-1.0,
-1.0,
-1.0,
-1.0,
-0.1,
0.0,
0.0,
0.0,
0.0,
-0.5,
]
)

true_ub = jp.array(
[
jp.inf,
jp.inf,
jp.inf,
1.0,
1.0,
1.0,
1.0,
0.1,
1.0,
1.0,
1.0,
1.0,
0.5,
]
)
assert jp.array_equal(lb, true_lb)
assert jp.array_equal(ub, true_ub)
assert part_names == [
"root",
"root",
"root",
"root",
"root",
"root",
"root",
"hingejoint",
"balljoint",
"balljoint",
"balljoint",
"balljoint",
"slidejoint",
]

0 comments on commit 7fef98f

Please sign in to comment.