Skip to content

Commit 06b90fe

Browse files
committed
rollout add chunk_divisor parameter
1 parent 75603ee commit 06b90fe

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

doc/changelog.rst

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ Bug fixes
99
^^^^^^^^^
1010
- Fixed a bug in the box-sphere collider, depth was incorrect for deep penetrations (:github:issue:`2206`).
1111

12+
Python bindings
13+
^^^^^^^^^^^^^^^
14+
- :ref:`rollout<PyRollout>` can now accept sequences of MjData of length ``nthread``. If passed, :ref:`rollout<PyRollout>`
15+
will automatically create a persistent threadpool and parallelize rollouts.
16+
1217
Version 3.2.6 (Dec 2, 2024)
1318
---------------------------
1419

python/mujoco/rollout.cc

+7-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Roll out open-loop trajectories from initial states, get resulting states and se
3939
4040
input arguments (required):
4141
model list of MjModel instances of length nroll
42-
data associated instance of MjData
42+
data list of associated MjData instances of length nthread
4343
nstep integer, number of steps to be taken for each trajectory
4444
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
4545
state0 (nroll x nstate) nroll initial state vectors,
@@ -50,6 +50,8 @@ Roll out open-loop trajectories from initial states, get resulting states and se
5050
output arguments (optional):
5151
state (nroll x nstep x nstate) nroll nstep states
5252
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
53+
chunk_divisor integer, determines threadpool chunk size according to
54+
chunk_size = max(1, nroll / (nthread * chunk_divisor)
5355
)";
5456

5557
// C-style rollout function, assumes all arguments are valid
@@ -238,7 +240,8 @@ PYBIND11_MODULE(_rollout, pymodule) {
238240
std::optional<const PyCArray> warmstart0,
239241
std::optional<const PyCArray> control,
240242
std::optional<const PyCArray> state,
241-
std::optional<const PyCArray> sensordata
243+
std::optional<const PyCArray> sensordata,
244+
int chunk_divisor
242245
) {
243246
// get raw pointers
244247
int nroll = state0.shape(0);
@@ -278,7 +281,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
278281

279282
// call unsafe rollout function
280283
if (nthread > 1 && nroll > 1) {
281-
int chunk_size = std::max(1, nroll / (10 * nthread));
284+
int chunk_size = std::max(1, nroll / (chunk_divisor * nthread));
282285
InterceptMjErrors(_unsafe_rollout_threaded)(
283286
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
284287
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
@@ -300,6 +303,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
300303
py::arg("control") = py::none(),
301304
py::arg("state") = py::none(),
302305
py::arg("sensordata") = py::none(),
306+
py::arg("chunk_divisor") = 10,
303307
py::doc(rollout_doc)
304308
);
305309
}

python/mujoco/rollout.py

+5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def rollout(
3535
initial_warmstart: Optional[npt.ArrayLike] = None,
3636
state: Optional[npt.ArrayLike] = None,
3737
sensordata: Optional[npt.ArrayLike] = None,
38+
chunk_divisor: int = 10,
3839
):
3940
"""Rolls out open-loop trajectories from initial states, get subsequent states and sensor values.
4041
@@ -59,6 +60,8 @@ def rollout(
5960
(nroll x nstep x nstate)
6061
sensordata: Sensor data output array (optional).
6162
(nroll x nstep x nsensordata)
63+
chunk_divisor: Determines threadpool chunk size according to
64+
chunk_size = max(1, nroll / (nthread * chunk_divisor)
6265
6366
Returns:
6467
state:
@@ -85,6 +88,7 @@ def rollout(
8588
control,
8689
state,
8790
sensordata,
91+
chunk_divisor,
8892
)
8993
return state, sensordata
9094

@@ -198,6 +202,7 @@ def rollout(
198202
control,
199203
state,
200204
sensordata,
205+
chunk_divisor,
201206
)
202207

203208
# return outputs

0 commit comments

Comments
 (0)