Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predictive Control API #48

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions ambersim/control/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import jax
from flax import struct


@struct.dataclass
class ControllerParams:
"""The parameters for generic controllers.

This is left completely empty for maximum flexibility in the API. Some examples:
- "Regular" inputs into feedback controllers (e.g., the state) belong here.
- Non-Markovian controllers can pass histories in this params object.
- Parameters of the controller that you may randomize/optimize go here.
"""


@struct.dataclass
class Controller:
"""The API for a generic controller.

See the notes in TrajectoryOptimizer on the generality of this class - much of the same applies.
"""

def compute(self, ctrl_params: ControllerParams) -> jax.Array:
"""Computes a control input.

Args:
ctrl_params: ControllerParams

Returns:
u (shape=(nu,)): The control input.
"""
raise NotImplementedError
149 changes: 149 additions & 0 deletions ambersim/control/predictive_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from mujoco import mjx

from ambersim.control.base import Controller, ControllerParams
from ambersim.trajopt.base import TrajectoryOptimizer
from ambersim.trajopt.shooting import PDPredictiveSamplerParams, PredictiveSampler, VanillaPredictiveSamplerParams

# ########### #
# GENERIC API #
# ########### #


@struct.dataclass
class PredictiveControllerParams(ControllerParams):
"""The generic API for predictive controller params."""


@struct.dataclass
class PredictiveController(Controller):
"""The generic API for a predictive controller."""

trajectory_optimizer: TrajectoryOptimizer
model: mjx.Model

def compute(self, ctrl_params: PredictiveControllerParams) -> jax.Array:
"""Computes a control input using forward prediction."""
raise NotImplementedError


# ################### #
# PREDICTIVE SAMPLING #
# ################### #


@struct.dataclass
class VanillaPredictiveSamplingControllerParams(PredictiveControllerParams):
"""Vanilla predictive sampling controller params."""

key: jax.Array # random key for sampling
x: jax.Array # shape=(nq + nv,) current state
guess: jax.Array # shape=(N, nu) current guess


@struct.dataclass
class VanillaPredictiveSamplingController(PredictiveController):
"""Vanilla predictive sampling controller."""

def __post_init__(self) -> None:
"""Post-initialization check."""
assert isinstance(
self.trajectory_optimizer, PredictiveSampler
), "trajectory_optimizer must be a PredictiveSampler!"

def compute(self, ctrl_params: VanillaPredictiveSamplingControllerParams) -> jax.Array:
"""Computes a control input using forward prediction.

Args:
ctrl_params: Inputs into the controller.

Returns:
u (shape=(nu,)): The control input.
"""
return self.compute_with_us_star(ctrl_params)[0]

def compute_with_us_star(
self, ctrl_params: VanillaPredictiveSamplingControllerParams
) -> Tuple[jax.Array, jax.Array]:
"""Computes a control input using forward prediction + the optimal sequence of guesses.

This is needed in practice because the current optimal sequence is used to warm start the sampling distribution
for the next call of the controller.

Args:
ctrl_params: Inputs into the controller.

Returns:
u (shape=(nu,)): The control input.
us_star (shape=(N, nu)): The optimal control sequence.
"""
to_params = VanillaPredictiveSamplerParams(
key=ctrl_params.key,
x0=ctrl_params.x,
guess=ctrl_params.guess,
)
xs_star, us_star = self.trajectory_optimizer.optimize(to_params)
u = us_star[0, :]
return u, us_star


@struct.dataclass
class PDPredictiveSamplingControllerParams(PredictiveControllerParams):
"""PD predictive sampling controller params."""

key: jax.Array # random key for sampling
x: jax.Array # shape=(nq + nv,) current state
guess: jax.Array # shape=(N, nq) current guess
kp: float # proportional gain
kd: float # derivative gain


@struct.dataclass
class PDPredictiveSamplingController(PredictiveController):
"""PD predictive sampling controller."""

def __post_init__(self) -> None:
"""Post-initialization check."""
assert isinstance(
self.trajectory_optimizer, PredictiveSampler
), "trajectory_optimizer must be a PredictiveSampler!"

def compute(self, ctrl_params: PDPredictiveSamplingControllerParams) -> jax.Array:
"""Computes a control input using forward prediction.

Args:
ctrl_params: Inputs into the controller.

Returns:
u (shape=(nu,)): The control input.
"""
return self.compute_with_qs_star(ctrl_params)[0]

def compute_with_qs_star(self, ctrl_params: PDPredictiveSamplingControllerParams) -> Tuple[jax.Array, jax.Array]:
"""Computes a control input using forward prediction + the optimal sequence of guesses.

This is needed in practice because the current optimal sequence is used to warm start the sampling distribution
for the next call of the controller.

Args:
ctrl_params: Inputs into the controller.

Returns:
u (shape=(nu,)): The control input.
us_star (shape=(N, nu)): The optimal control sequence.
"""
to_params = PDPredictiveSamplerParams(
key=ctrl_params.key,
x0=ctrl_params.x,
guess=ctrl_params.guess,
kp=ctrl_params.kp,
kd=ctrl_params.kd,
)
xs_star, us_star = self.trajectory_optimizer.optimize(to_params)
# u = us_star[0, :]
qs_star = xs_star[:, : self.model.nq] # the 0th index is the current state, so return the 1st index
return qs_star[1, :], qs_star
Binary file not shown.
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/fileback.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/filedown.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/filefront.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/fileleft.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/fileright.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/fileup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/grayback.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/graydown.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/grayfront.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/grayleft.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/grayright.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/grayup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ambersim/models/allegro_hand/assets/link_0.0.stl
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/link_1.0.stl
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/link_14.0.stl
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/link_15.0.stl
Binary file not shown.
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/link_2.0.stl
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/link_3.0.stl
Binary file not shown.
Binary file not shown.
Binary file added ambersim/models/allegro_hand/assets/link_4.0.stl
Diff not rendered.
29 changes: 29 additions & 0 deletions ambersim/models/allegro_hand/cube.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
<mujoco>
<asset>
<texture name="cube" type="cube"
fileup="assets/fileup.png" fileback="assets/fileback.png"
filedown="assets/filedown.png" filefront="assets/filefront.png"
fileleft="assets/fileleft.png" fileright="assets/fileright.png"/>
<material name="cube" texture="cube"/>
<texture name="graycube" type="cube" fileup="assets/grayup.png"
fileback="assets/grayback.png" filedown="assets/graydown.png"
filefront="assets/grayfront.png" fileleft="assets/grayleft.png"
fileright="assets/grayright.png"/>
<material name="graycube" texture="graycube"/>
</asset>
<worldbody>
<light pos="0 0 1"/>
<body name="cube" pos="0.0 0.0 0.035" quat="1 0 0 0">
<freejoint/>
<geom name="cube" type="box" size=".022 .022 .022" mass=".126" material="cube"/>
</body>
</worldbody>

<sensor>
<framepos name="trace0" objtype="body" objname="cube"/>
<framepos name="cube_position" objtype="body" objname="cube"/>
<framequat name="cube_orientation" objtype="body" objname="cube"/>
<framelinvel name="cube_linear_velocity" objtype="body" objname="cube"/>
<frameangvel name="cube_angular_velocity" objtype="body" objname="cube"/>
</sensor>
</mujoco>
Loading