Skip to content

Commit 943eb6b

Browse files
committed
rollout accepts a list of models of length nroll
1 parent 6e129a5 commit 943eb6b

File tree

5 files changed

+102
-43
lines changed

5 files changed

+102
-43
lines changed

doc/changelog.rst

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Python bindings
1616
- Added ``bind`` method and removed id attribute from :ref:`mjSpec` objects. Using ids is error prone in scenarios of repeated attachment and
1717
detachment. Python users are encouraged to use names for unique identification of model elements.
1818
- Removed ``nroll`` argument from :ref:`rollout<PyRollout>` because its value can always be inferred.
19+
- :ref:`rollout<PyRollout>` can now accept lists of MjModel of length ``nroll``.
1920

2021
Bug fixes
2122
^^^^^^^^^

doc/python.rst

+1
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,7 @@ states and sensor values. The basic usage form is
700700
701701
state, sensordata = rollout.rollout(model, data, initial_state, control)
702702
703+
``model`` is either a single instance of MjModel or a list of compatible MjModel of length ``nroll``.
703704
``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where
704705
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the
705706
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are

python/mujoco/rollout.cc

+29-26
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ const auto rollout_doc = R"(
3737
Roll out open-loop trajectories from initial states, get resulting states and sensor values.
3838
3939
input arguments (required):
40-
model instance of MjModel
40+
model list of MjModel instances of length nroll
4141
data associated instance of MjData
4242
nstep integer, number of steps to be taken for each trajectory
4343
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
@@ -54,18 +54,18 @@ Roll out open-loop trajectories from initial states, get resulting states and se
5454
// C-style rollout function, assumes all arguments are valid
5555
// all input fields of d are initialised, contents at call time do not matter
5656
// after returning, d will contain the last step of the last rollout
57-
void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned int control_spec,
57+
void _unsafe_rollout(const mjModel** m, mjData* d, int nroll, int nstep, unsigned int control_spec,
5858
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
5959
mjtNum* state, mjtNum* sensordata) {
6060
// sizes
61-
int nstate = mj_stateSize(m, mjSTATE_FULLPHYSICS);
62-
int ncontrol = mj_stateSize(m, control_spec);
63-
int nv = m->nv, nbody = m->nbody, neq = m->neq;
64-
int nsensordata = m->nsensordata;
61+
int nstate = mj_stateSize(m[0], mjSTATE_FULLPHYSICS);
62+
int ncontrol = mj_stateSize(m[0], control_spec);
63+
int nv = m[0]->nv, nbody = m[0]->nbody, neq = m[0]->neq;
64+
int nsensordata = m[0]->nsensordata;
6565

6666
// clear user inputs if unspecified
6767
if (!(control_spec & mjSTATE_CTRL)) {
68-
mju_zero(d->ctrl, m->nu);
68+
mju_zero(d->ctrl, m[0]->nu);
6969
}
7070
if (!(control_spec & mjSTATE_QFRC_APPLIED)) {
7171
mju_zero(d->qfrc_applied, nv);
@@ -75,26 +75,26 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned
7575
}
7676
if (!(control_spec & mjSTATE_MOCAP_POS)) {
7777
for (int i = 0; i < nbody; i++) {
78-
int id = m->body_mocapid[i];
79-
if (id >= 0) mju_copy3(d->mocap_pos+3*id, m->body_pos+3*i);
78+
int id = m[0]->body_mocapid[i];
79+
if (id >= 0) mju_copy3(d->mocap_pos+3*id, m[0]->body_pos+3*i);
8080
}
8181
}
8282
if (!(control_spec & mjSTATE_MOCAP_QUAT)) {
8383
for (int i = 0; i < nbody; i++) {
84-
int id = m->body_mocapid[i];
85-
if (id >= 0) mju_copy4(d->mocap_quat+4*id, m->body_quat+4*i);
84+
int id = m[0]->body_mocapid[i];
85+
if (id >= 0) mju_copy4(d->mocap_quat+4*id, m[0]->body_quat+4*i);
8686
}
8787
}
8888
if (!(control_spec & mjSTATE_EQ_ACTIVE)) {
8989
for (int i = 0; i < neq; i++) {
90-
d->eq_active[i] = m->eq_active0[i];
90+
d->eq_active[i] = m[0]->eq_active0[i];
9191
}
9292
}
9393

9494
// loop over rollouts
9595
for (int r = 0; r < nroll; r++) {
9696
// set initial state
97-
mj_setState(m, d, state0 + r*nstate, mjSTATE_FULLPHYSICS);
97+
mj_setState(m[r], d, state0 + r*nstate, mjSTATE_FULLPHYSICS);
9898

9999
// set warmstart accelerations
100100
if (warmstart0) {
@@ -124,7 +124,7 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned
124124
for (; t < nstep; t++) {
125125
int step = r*nstep + t;
126126
if (state) {
127-
mj_getState(m, d, state + step*nstate, mjSTATE_FULLPHYSICS);
127+
mj_getState(m[r], d, state + step*nstate, mjSTATE_FULLPHYSICS);
128128
}
129129
if (sensordata) {
130130
mju_copy(sensordata + step*nsensordata, d->sensordata, nsensordata);
@@ -137,15 +137,15 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nroll, int nstep, unsigned
137137

138138
// controls
139139
if (control) {
140-
mj_setState(m, d, control + step*ncontrol, control_spec);
140+
mj_setState(m[r], d, control + step*ncontrol, control_spec);
141141
}
142142

143143
// step
144-
mj_step(m, d);
144+
mj_step(m[r], d);
145145

146146
// copy out new state
147147
if (state) {
148-
mj_getState(m, d, state + step*nstate, mjSTATE_FULLPHYSICS);
148+
mj_getState(m[r], d, state + step*nstate, mjSTATE_FULLPHYSICS);
149149
}
150150

151151
// copy out sensor values
@@ -188,15 +188,20 @@ PYBIND11_MODULE(_rollout, pymodule) {
188188
// get subsequent states and corresponding sensor values
189189
pymodule.def(
190190
"rollout",
191-
[](const MjModelWrapper& m, MjDataWrapper& d,
191+
[](py::list m, MjDataWrapper& d,
192192
int nstep, unsigned int control_spec,
193193
const PyCArray state0,
194194
std::optional<const PyCArray> warmstart0,
195195
std::optional<const PyCArray> control,
196196
std::optional<const PyCArray> state,
197197
std::optional<const PyCArray> sensordata
198198
) {
199-
const raw::MjModel* model = m.get();
199+
// get raw pointers
200+
int nroll = state0.shape(0);
201+
const raw::MjModel* model_ptrs[nroll];
202+
for (int r = 0; r < nroll; r++) {
203+
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
204+
}
200205
raw::MjData* data = d.get();
201206

202207
// check that some steps need to be taken, return if not
@@ -205,19 +210,17 @@ PYBIND11_MODULE(_rollout, pymodule) {
205210
}
206211

207212
// get sizes
208-
int nstate = mj_stateSize(model, mjSTATE_FULLPHYSICS);
209-
int ncontrol = mj_stateSize(model, control_spec);
210-
int nroll = state0.shape(0);
213+
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS);
214+
int ncontrol = mj_stateSize(model_ptrs[0], control_spec);
211215

212-
// get raw pointers
213216
mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
214217
mjtNum* warmstart0_ptr = get_array_ptr(warmstart0, "warmstart0", nroll,
215-
1, model->nv);
218+
1, model_ptrs[0]->nv);
216219
mjtNum* control_ptr = get_array_ptr(control, "control", nroll,
217220
nstep, ncontrol);
218221
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate);
219222
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll,
220-
nstep, model->nsensordata);
223+
nstep, model_ptrs[0]->nsensordata);
221224

222225
// perform rollouts
223226
{
@@ -226,7 +229,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
226229

227230
// call unsafe rollout function
228231
InterceptMjErrors(_unsafe_rollout)(
229-
model, data, nroll, nstep, control_spec, state0_ptr,
232+
model_ptrs, data, nroll, nstep, control_spec, state0_ptr,
230233
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
231234
}
232235
},

python/mujoco/rollout.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
# ==============================================================================
1515
"""Roll out open-loop trajectories from initial states, get subsequent states and sensor values."""
1616

17-
from typing import Optional
17+
from typing import Optional, Union
1818

1919
import mujoco
2020
from mujoco import _rollout
2121
import numpy as np
2222
from numpy import typing as npt
2323

2424

25-
def rollout(model: mujoco.MjModel,
25+
def rollout(model: Union[mujoco.MjModel, list[mujoco.MjModel]],
2626
data: mujoco.MjData,
2727
initial_state: npt.ArrayLike,
2828
control: Optional[npt.ArrayLike] = None,
@@ -41,7 +41,7 @@ def rollout(model: mujoco.MjModel,
4141
Allocates outputs if none are given.
4242
4343
Args:
44-
model: An mjModel instance.
44+
model: An mjModel or a list of MjModel with the same size signature.
4545
data: An associated mjData instance.
4646
initial_state: Array of initial states from which to roll out trajectories.
4747
([nroll or 1] x nstate)
@@ -90,6 +90,7 @@ def rollout(model: mujoco.MjModel,
9090
state=state,
9191
sensordata=sensordata)
9292

93+
9394
# check number of dimensions
9495
_check_number_of_dimensions(2,
9596
initial_state=initial_state,
@@ -108,29 +109,49 @@ def rollout(model: mujoco.MjModel,
108109
state = _ensure_3d(state)
109110
sensordata = _ensure_3d(sensordata)
110111

111-
# check trailing dimensions
112-
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS.value)
113-
_check_trailing_dimension(nstate, initial_state=initial_state, state=state)
114-
ncontrol = mujoco.mj_stateSize(model, control_spec)
115-
_check_trailing_dimension(ncontrol, control=control)
116-
_check_trailing_dimension(model.nv, initial_warmstart=initial_warmstart)
117-
_check_trailing_dimension(model.nsensordata, sensordata=sensordata)
118-
119112
# infer nroll, check for incompatibilities
120113
nroll = _infer_dimension(0, 1,
121114
initial_state=initial_state,
122115
initial_warmstart=initial_warmstart,
123116
control=control,
124117
state=state,
125118
sensordata=sensordata)
119+
if isinstance(model, list) and nroll == 1:
120+
nroll = len(model)
121+
122+
if isinstance(model, list) and len(model) != nroll:
123+
raise ValueError(f'nroll inferred as {nroll} '
124+
f'but model is length {len(model)}')
125+
elif not isinstance(model, list):
126+
model = [model] # Use a length 1 list to simplify code below
126127

127128
# infer nstep, check for incompatibilities
128129
nstep = _infer_dimension(1, nstep or 1,
129130
control=control,
130131
state=state,
131132
sensordata=sensordata)
132133

133-
# tile input arrays if required (singleton expansion)
134+
# get nstate/ncontrol/nv/nsensordata
135+
# check that they are equal across models
136+
nstate = mujoco.mj_stateSize(model[0], mujoco.mjtState.mjSTATE_FULLPHYSICS.value)
137+
ncontrol = mujoco.mj_stateSize(model[0], control_spec)
138+
nv = model[0].nv
139+
nsensordata = model[0].nsensordata
140+
for m in model[1:]:
141+
if (nstate != mujoco.mj_stateSize(m, mujoco.mjtState.mjSTATE_FULLPHYSICS.value)
142+
or ncontrol != mujoco.mj_stateSize(m, control_spec)
143+
or nv != m.nv
144+
or nsensordata != m.nsensordata):
145+
raise ValueError('models are not compatible')
146+
147+
# check trailing dimensions
148+
_check_trailing_dimension(nstate, initial_state=initial_state, state=state)
149+
_check_trailing_dimension(ncontrol, control=control)
150+
_check_trailing_dimension(nv, initial_warmstart=initial_warmstart)
151+
_check_trailing_dimension(nsensordata, sensordata=sensordata)
152+
153+
# tile input arrays/lists if required (singleton expansion)
154+
model = model*nroll if len(model) == 1 else model
134155
initial_state = _tile_if_required(initial_state, nroll)
135156
initial_warmstart = _tile_if_required(initial_warmstart, nroll)
136157
control = _tile_if_required(control, nroll, nstep)
@@ -139,7 +160,7 @@ def rollout(model: mujoco.MjModel,
139160
if state is None:
140161
state = np.empty((nroll, nstep, nstate))
141162
if sensordata is None:
142-
sensordata = np.empty((nroll, nstep, model.nsensordata))
163+
sensordata = np.empty((nroll, nstep, nsensordata))
143164

144165
# call rollout
145166
_rollout.rollout(model, data, nstep, control_spec, initial_state,

python/mujoco/rollout_test.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,34 @@ def test_multi_rollout(self, model_name):
334334
np.testing.assert_array_equal(state, py_state)
335335
np.testing.assert_array_equal(sensordata, py_sensordata)
336336

337+
@parameterized.parameters(ALL_MODELS.keys())
338+
def test_multi_model(self, model_name):
339+
nroll = 3 # number of initial states and models
340+
nstep = 3 # number of timesteps
341+
342+
spec = mujoco.MjSpec.from_string(ALL_MODELS[model_name])
343+
344+
if len(spec.bodies) > 1:
345+
model = []
346+
for i in range(nroll):
347+
body = spec.bodies[1]
348+
assert body.name != 'world'
349+
body.pos = body.pos + i
350+
model.append(spec.compile())
351+
else:
352+
model = [spec.compile() for i in range(nroll)]
353+
354+
nstate = mujoco.mj_stateSize(model[0], mujoco.mjtState.mjSTATE_FULLPHYSICS)
355+
data = mujoco.MjData(model[0])
356+
357+
initial_state = np.random.randn(nroll, nstate)
358+
control = np.random.randn(nroll, nstep, model[0].nu)
359+
state, sensordata = rollout.rollout(model, data, initial_state, control)
360+
361+
py_state, py_sensordata = py_rollout(model, data, initial_state, control)
362+
np.testing.assert_array_equal(state, py_state)
363+
np.testing.assert_array_equal(sensordata, py_sensordata)
364+
337365
@parameterized.parameters(ALL_MODELS.keys())
338366
def test_multi_rollout_fixed_ctrl_infer_from_output(self, model_name):
339367
model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name])
@@ -430,8 +458,9 @@ def test_threading(self):
430458
def thread_initializer():
431459
thread_local.data = mujoco.MjData(model)
432460

461+
model_list = [model]*nroll
433462
def call_rollout(initial_state, control, state, sensordata):
434-
rollout.rollout(model, thread_local.data, initial_state, control,
463+
rollout.rollout(model_list, thread_local.data, initial_state, control,
435464
skip_checks=True,
436465
nstep=nstep, state=state, sensordata=sensordata)
437466

@@ -677,13 +706,17 @@ def py_rollout(model, data, initial_state, control,
677706
control = ensure_3d(control)
678707
nroll = initial_state.shape[0]
679708
nstep = control.shape[1]
680-
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
709+
710+
if isinstance(model, mujoco.MjModel):
711+
model = [model]*nroll
712+
713+
nstate = mujoco.mj_stateSize(model[0], mujoco.mjtState.mjSTATE_FULLPHYSICS)
681714

682715
state = np.empty((nroll, nstep, nstate))
683-
sensordata = np.empty((nroll, nstep, model.nsensordata))
716+
sensordata = np.empty((nroll, nstep, model[0].nsensordata))
684717
for r in range(nroll):
685718
state_r, sensordata_r = one_rollout(
686-
model, data, initial_state[r], control[r], control_spec
719+
model[r], data, initial_state[r], control[r], control_spec
687720
)
688721
state[r] = state_r
689722
sensordata[r] = sensordata_r

0 commit comments

Comments
 (0)