@@ -39,7 +39,7 @@ Roll out open-loop trajectories from initial states, get resulting states and se
39
39
40
40
input arguments (required):
41
41
model list of MjModel instances of length nroll
42
- data associated instance of MjData
42
+ data list of associated MjData instances of length nthread
43
43
nstep integer, number of steps to be taken for each trajectory
44
44
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
45
45
state0 (nroll x nstate) nroll initial state vectors,
@@ -50,6 +50,8 @@ Roll out open-loop trajectories from initial states, get resulting states and se
50
50
output arguments (optional):
51
51
state (nroll x nstep x nstate) nroll nstep states
52
52
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
53
+ chunk_divisor integer, determines threadpool chunk size according to
54
+ chunk_size = max(1, nroll / (nthread * chunk_divisor)
53
55
)" ;
54
56
55
57
// C-style rollout function, assumes all arguments are valid
@@ -238,7 +240,8 @@ PYBIND11_MODULE(_rollout, pymodule) {
238
240
std::optional<const PyCArray> warmstart0,
239
241
std::optional<const PyCArray> control,
240
242
std::optional<const PyCArray> state,
241
- std::optional<const PyCArray> sensordata
243
+ std::optional<const PyCArray> sensordata,
244
+ int chunk_divisor
242
245
) {
243
246
// get raw pointers
244
247
int nroll = state0.shape (0 );
@@ -278,7 +281,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
278
281
279
282
// call unsafe rollout function
280
283
if (nthread > 1 && nroll > 1 ) {
281
- int chunk_size = std::max (1 , nroll / (10 * nthread));
284
+ int chunk_size = std::max (1 , nroll / (chunk_divisor * nthread));
282
285
InterceptMjErrors (_unsafe_rollout_threaded)(
283
286
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
284
287
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
@@ -300,6 +303,7 @@ PYBIND11_MODULE(_rollout, pymodule) {
300
303
py::arg (" control" ) = py::none (),
301
304
py::arg (" state" ) = py::none (),
302
305
py::arg (" sensordata" ) = py::none (),
306
+ py::arg (" chunk_divisor" ) = 10 ,
303
307
py::doc (rollout_doc)
304
308
);
305
309
}
0 commit comments