Skip to content

Commit 169cf99

Browse files
committed
rollout exchange chunk_divisor arg for chunk_size
1 parent 298ab2f commit 169cf99

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

python/mujoco/rollout.cc

+12-6
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ Roll out open-loop trajectories from initial states, get resulting states and se
5050
output arguments (optional):
5151
state (nroll x nstep x nstate) nroll nstep states
5252
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
53-
chunk_divisor integer, determines threadpool chunk size according to
54-
chunk_size = max(1, nroll / (nthread * chunk_divisor)
53+
chunk_size integer, determines threadpool chunk size. If unspecified
54+
chunk_size = max(1, nroll / (nthread * 10)
5555
)";
5656

5757
// C-style rollout function, assumes all arguments are valid
@@ -241,7 +241,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
241241
std::optional<const PyCArray> control,
242242
std::optional<const PyCArray> state,
243243
std::optional<const PyCArray> sensordata,
244-
int chunk_divisor
244+
std::optional<int> chunk_size
245245
) {
246246
// get raw pointers
247247
int nroll = state0.shape(0);
@@ -281,11 +281,17 @@ PYBIND11_MODULE(_rollout, pymodule) {
281281

282282
// call unsafe rollout function
283283
if (nthread > 1 && nroll > 1) {
284-
int chunk_size = std::max(1, nroll / (chunk_divisor * nthread));
284+
int chunk_size_final = 1;
285+
if (!chunk_size.has_value()) {
286+
chunk_size_final = std::max(1, nroll / (10 * nthread));
287+
}
288+
else {
289+
chunk_size_final = *chunk_size;
290+
}
285291
InterceptMjErrors(_unsafe_rollout_threaded)(
286292
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
287293
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
288-
nthread, chunk_size);
294+
nthread, chunk_size_final);
289295
}
290296
else {
291297
InterceptMjErrors(_unsafe_rollout)(
@@ -303,7 +309,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
303309
py::arg("control") = py::none(),
304310
py::arg("state") = py::none(),
305311
py::arg("sensordata") = py::none(),
306-
py::arg("chunk_divisor") = 10,
312+
py::arg("chunk_size") = py::none(),
307313
py::doc(rollout_doc)
308314
);
309315
}

python/mujoco/rollout.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def rollout(
3535
initial_warmstart: Optional[npt.ArrayLike] = None,
3636
state: Optional[npt.ArrayLike] = None,
3737
sensordata: Optional[npt.ArrayLike] = None,
38-
chunk_divisor: int = 10,
38+
chunk_size: int = None,
3939
):
4040
"""Rolls out open-loop trajectories from initial states, get subsequent states and sensor values.
4141
@@ -60,8 +60,8 @@ def rollout(
6060
(nroll x nstep x nstate)
6161
sensordata: Sensor data output array (optional).
6262
(nroll x nstep x nsensordata)
63-
chunk_divisor: Determines threadpool chunk size according to
64-
chunk_size = max(1, nroll / (nthread * chunk_divisor)
63+
chunk_size: Determines threadpool chunk size. If unspecified,
64+
chunk_size = max(1, nroll / (nthread * 10)
6565
6666
Returns:
6767
state:
@@ -88,7 +88,7 @@ def rollout(
8888
control,
8989
state,
9090
sensordata,
91-
chunk_divisor,
91+
chunk_size,
9292
)
9393
return state, sensordata
9494

@@ -102,6 +102,8 @@ def rollout(
102102
# check types
103103
if nstep and not isinstance(nstep, int):
104104
raise ValueError('nstep must be an integer')
105+
if chunk_size and not isinstance(chunk_size, int):
106+
raise ValueError('chunk_size must be an integer')
105107
_check_must_be_numeric(
106108
initial_state=initial_state,
107109
initial_warmstart=initial_warmstart,
@@ -202,7 +204,7 @@ def rollout(
202204
control,
203205
state,
204206
sensordata,
205-
chunk_divisor,
207+
chunk_size,
206208
)
207209

208210
# return outputs

0 commit comments

Comments
 (0)